mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
76 Commits
feat/build
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1750c833ee | ||
|
|
85f0d8353a | ||
|
|
866563ad25 | ||
|
|
e79928a815 | ||
|
|
1771ed3bef | ||
|
|
550fa5a319 | ||
|
|
8528dffbf2 | ||
|
|
8fbf6a4b09 | ||
|
|
239148596c | ||
|
|
a880d73481 | ||
|
|
80bfd64ffa | ||
|
|
0076ad2a1a | ||
|
|
edb3d322f0 | ||
|
|
9381057079 | ||
|
|
f21a36ca37 | ||
|
|
ee5382a064 | ||
|
|
b80e5ea987 | ||
|
|
3d4fcfacb6 | ||
|
|
32eac6d52e | ||
|
|
9762f4cde7 | ||
|
|
76901ba22f | ||
|
|
23b65939f3 | ||
|
|
1c27eaac53 | ||
|
|
923b164794 | ||
|
|
e86ac21c43 | ||
|
|
94224be841 | ||
|
|
da4bdc7ab9 | ||
|
|
7176cecf25 | ||
|
|
f35210761c | ||
|
|
1ebcf85669 | ||
|
|
ab7c38bda7 | ||
|
|
0f67e45d05 | ||
|
|
b9ce37600e | ||
|
|
3921deaef1 | ||
|
|
f01f668674 | ||
|
|
f7a3491f91 | ||
|
|
cbff3b53d3 | ||
|
|
5b9a4c52c9 | ||
|
|
0ce1c90b55 | ||
|
|
d4c6eb9adc | ||
|
|
1bb91b53b7 | ||
|
|
a5f9c43a41 | ||
|
|
1240f38f75 | ||
|
|
f617f50f0b | ||
|
|
943a1df815 | ||
|
|
593001e0c8 | ||
|
|
e1db8234a3 | ||
|
|
282173be9d | ||
|
|
5d9a169e04 | ||
|
|
6fd1050457 | ||
|
|
02708bcd00 | ||
|
|
156d61fe5c | ||
|
|
5a29de0e0e | ||
|
|
e657472162 | ||
|
|
4d00e0f179 | ||
|
|
1d7282b5f3 | ||
|
|
e3591fcaa3 | ||
|
|
876dc32e17 | ||
|
|
616e29f5e4 | ||
|
|
280a98ad38 | ||
|
|
c7f2a7dd03 | ||
|
|
6d0e2063ec | ||
|
|
8b577ae194 | ||
|
|
d8f5f783ae | ||
|
|
82d22f3680 | ||
|
|
50622333d1 | ||
|
|
27af5782a9 | ||
|
|
522f932e67 | ||
|
|
a6124b06d5 | ||
|
|
ae660ea04f | ||
|
|
2479f3a1c4 | ||
|
|
8153306384 | ||
|
|
9c3d100a22 | ||
|
|
fc3bf6c154 | ||
|
|
e32d258a7e | ||
|
|
3e86544bfe |
@@ -2,7 +2,7 @@
|
||||
name: pr-address
|
||||
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
|
||||
user-invocable: true
|
||||
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
@@ -17,18 +17,70 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Fetch comments (all sources)
|
||||
## Read the PR description
|
||||
|
||||
Understand the **Why / What / How** before addressing comments — you need context to make good fixes:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews # top-level reviews
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments # inline review comments
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
**Bots to watch for:**
|
||||
- `autogpt-reviewer` — posts "Blockers", "Should Fix", "Nice to Have". Address ALL of them.
|
||||
- `sentry[bot]` — bug predictions. Fix real bugs, explain false positives.
|
||||
- `coderabbitai[bot]` — automated review. Address actionable items.
|
||||
## Fetch comments (all sources)
|
||||
|
||||
### 1. Inline review threads — GraphQL (primary source of actionable items)
|
||||
|
||||
Use GraphQL to fetch inline threads. It natively exposes `isResolved`, returns threads already grouped with all replies, and paginates via cursor — no manual thread reconstruction needed.
|
||||
|
||||
```bash
|
||||
gh api graphql -f query='
|
||||
{
|
||||
repository(owner: "Significant-Gravitas", name: "AutoGPT") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads(first: 100) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes {
|
||||
id
|
||||
isResolved
|
||||
path
|
||||
comments(last: 1) {
|
||||
nodes { databaseId body author { login } createdAt }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
If `pageInfo.hasNextPage` is true, fetch subsequent pages by adding `after: "<endCursor>"` to `reviewThreads(first: 100, after: "...")` and repeat until `hasNextPage` is false.
|
||||
|
||||
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
|
||||
|
||||
### 2. Top-level reviews — REST (MUST paginate)
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
|
||||
```
|
||||
|
||||
**CRITICAL — always `--paginate`.** Reviews default to 30 per page. PRs can have 80–170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30 — including `autogpt-reviewer`'s structured review which is typically posted after several CI runs and sits well beyond the first page.
|
||||
|
||||
Two things to extract:
|
||||
- **Overall state**: look for `CHANGES_REQUESTED` or `APPROVED` reviews.
|
||||
- **Actionable feedback**: non-empty bodies only. Empty-body reviews are thread-resolution events — they indicate progress but have no feedback to act on.
|
||||
|
||||
**Where each reviewer posts:**
|
||||
- `autogpt-reviewer` — posts detailed structured reviews ("Blockers", "Should Fix", "Nice to Have") as **top-level reviews**. Not present on every PR. Address ALL items.
|
||||
- `sentry[bot]` — posts bug predictions as **inline threads**. Fix real bugs, explain false positives.
|
||||
- `coderabbitai[bot]` — posts summaries as **top-level reviews** AND actionable items as **inline threads**. Address actionable items.
|
||||
- Human reviewers — can post in any source. Address ALL non-empty feedback.
|
||||
|
||||
### 3. PR conversation comments — REST
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
|
||||
```
|
||||
|
||||
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
|
||||
|
||||
## For each unaddressed comment
|
||||
|
||||
@@ -40,8 +92,8 @@ Address comments **one at a time**: fix → commit → push → inline reply →
|
||||
|
||||
| Comment type | How to reply |
|
||||
|---|---|
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="Fixed in <commit-sha>: <description>"` |
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
|
||||
## Format and commit
|
||||
|
||||
@@ -61,7 +113,9 @@ kill $REST_PID 2>/dev/null; trap - EXIT
|
||||
```
|
||||
Never manually edit files in `src/app/api/__generated__/`.
|
||||
|
||||
Then commit and **push immediately** — never batch commits without pushing.
|
||||
Then commit and **push immediately** — never batch commits without pushing. Each fix should be visible on GitHub right away so CI can start and reviewers can see progress.
|
||||
|
||||
**Never push empty commits** (`git commit --allow-empty`) to re-trigger CI or bot checks. When a check fails, investigate the root cause (unchecked PR checklist, unaddressed review comments, code issues) and fix those directly. Empty commits add noise to git history.
|
||||
|
||||
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
@@ -69,11 +123,88 @@ For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
```text
|
||||
address comments → format → commit → push
|
||||
→ re-check comments → fix new ones → push
|
||||
→ wait for CI → re-check comments after CI settles
|
||||
→ wait for CI (while addressing new comments) → fix failures → push
|
||||
→ re-check comments after CI settles
|
||||
→ repeat until: all comments addressed AND CI green AND no new comments arriving
|
||||
```
|
||||
|
||||
While CI runs, stay productive: run local tests, address remaining comments.
|
||||
### Polling for CI + new comments
|
||||
|
||||
**The loop ends when:** CI fully green + all comments addressed + no new comments since CI settled.
|
||||
After pushing, poll for **both** CI status and new comments in a single loop. Do not use `gh pr checks --watch` — it blocks the tool and prevents reacting to new comments while CI is running.
|
||||
|
||||
> **Note:** `gh pr checks --watch --fail-fast` is tempting but it blocks the entire Bash tool call, meaning the agent cannot check for or address new comments until CI fully completes. Always poll manually instead.
|
||||
|
||||
**Polling loop — repeat every 30 seconds:**
|
||||
|
||||
1. Check CI status:
|
||||
```bash
|
||||
gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,name,link
|
||||
```
|
||||
Parse the results: if every check has `bucket` of `"pass"` or `"skipping"`, CI is green. If any has `"fail"`, CI has failed. Otherwise CI is still pending.
|
||||
|
||||
2. Check for merge conflicts:
|
||||
```bash
|
||||
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json mergeable --jq '.mergeable'
|
||||
```
|
||||
If the result is `"CONFLICTING"`, the PR has a merge conflict — see "Resolving merge conflicts" below. If `"UNKNOWN"`, GitHub is still computing mergeability — wait and re-check next poll.
|
||||
|
||||
3. Check for new/changed comments (all three sources):
|
||||
|
||||
**Inline threads** — re-run the GraphQL query from "Fetch comments". For each unresolved thread, record `{thread_id, last_comment_databaseId}` as your baseline. On each poll, action is needed if:
|
||||
- A new thread `id` appears that wasn't in the baseline (new thread), OR
|
||||
- An existing thread's `last_comment_databaseId` has changed (new reply on existing thread)
|
||||
|
||||
**Conversation comments:**
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
|
||||
```
|
||||
Compare total count and newest `id` against baseline. Filter to non-empty, non-bot, non-author-update messages.
|
||||
|
||||
**Top-level reviews:**
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
|
||||
```
|
||||
Watch for new non-empty reviews (`CHANGES_REQUESTED` or `COMMENTED` with body). Compare total count and newest `id` against baseline.
|
||||
|
||||
4. **React in this precedence order (first match wins):**
|
||||
|
||||
| What happened | Action |
|
||||
|---|---|
|
||||
| Merge conflict detected | See "Resolving merge conflicts" below. |
|
||||
| Mergeability is `UNKNOWN` | GitHub is still computing mergeability. Sleep 30 seconds, then restart polling from the top. |
|
||||
| New comments detected | Address them (fix → commit → push → reply). After pushing, re-fetch all comments to update your baseline, then restart this polling loop from the top (new commits invalidate CI status). |
|
||||
| CI failed (bucket == "fail") | Get failed check links: `gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,link --jq '.[] \| select(.bucket == "fail") \| .link'`. Extract run ID from link (format: `.../actions/runs/<run-id>/job/...`), read logs with `gh run view <run-id> --repo Significant-Gravitas/AutoGPT --log-failed`. Fix → commit → push → restart polling. |
|
||||
| CI green + no new comments | **Do not exit immediately.** Bots (coderabbitai, sentry) often post reviews shortly after CI settles. Continue polling for **2 more cycles (60s)** after CI goes green. Only exit after 2 consecutive green+quiet polls. |
|
||||
| CI pending + no new comments | Sleep 30 seconds, then poll again. |
|
||||
|
||||
**The loop ends when:** CI fully green + all comments addressed + **2 consecutive polls with no new comments after CI settled.**
|
||||
|
||||
### Resolving merge conflicts
|
||||
|
||||
1. Identify the PR's target branch and remote:
|
||||
```bash
|
||||
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json baseRefName --jq '.baseRefName'
|
||||
git remote -v # find the remote pointing to Significant-Gravitas/AutoGPT (typically 'upstream' in forks, 'origin' for direct contributors)
|
||||
```
|
||||
|
||||
2. Pull the latest base branch with a 3-way merge:
|
||||
```bash
|
||||
git pull {base-remote} {base-branch} --no-rebase
|
||||
```
|
||||
|
||||
3. Resolve conflicting files, then verify no conflict markers remain:
|
||||
```bash
|
||||
if grep -R -n -E '^(<<<<<<<|=======|>>>>>>>)' <conflicted-files>; then
|
||||
echo "Unresolved conflict markers found — resolve before proceeding."
|
||||
exit 1
|
||||
fi
|
||||
```
|
||||
|
||||
4. Stage and push:
|
||||
```bash
|
||||
git add <conflicted-files>
|
||||
git commit -m "Resolve merge conflicts with {base-branch}"
|
||||
git push
|
||||
```
|
||||
|
||||
5. Restart the polling loop from the top — new commits reset CI status.
|
||||
|
||||
@@ -17,6 +17,16 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Read the PR description
|
||||
|
||||
Before reading code, understand the **why**, **what**, and **how** from the PR description:
|
||||
|
||||
```bash
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
Every PR should have a Why / What / How structure. If any of these are missing, note it as feedback.
|
||||
|
||||
## Read the diff
|
||||
|
||||
```bash
|
||||
@@ -28,12 +38,14 @@ gh pr diff {N}
|
||||
Before posting anything, fetch existing inline comments to avoid duplicates:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
|
||||
```
|
||||
|
||||
## What to check
|
||||
|
||||
**Description quality:** Does the PR description cover Why (motivation/problem), What (summary of changes), and How (approach/implementation details)? If any are missing, request them — you can't judge the approach without understanding the problem and intent.
|
||||
|
||||
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
|
||||
|
||||
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
|
||||
|
||||
754
.claude/skills/pr-test/SKILL.md
Normal file
754
.claude/skills/pr-test/SKILL.md
Normal file
@@ -0,0 +1,754 @@
|
||||
---
|
||||
name: pr-test
|
||||
description: "E2E manual testing of PRs/branches using docker compose, agent-browser, and API calls. TRIGGER when user asks to manually test a PR, test a feature end-to-end, or run integration tests against a running system."
|
||||
user-invocable: true
|
||||
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "2.0.0"
|
||||
---
|
||||
|
||||
# Manual E2E Test
|
||||
|
||||
Test a PR/branch end-to-end by building the full platform, interacting via browser and API, capturing screenshots, and reporting results.
|
||||
|
||||
## Critical Requirements
|
||||
|
||||
These are NON-NEGOTIABLE. Every test run MUST satisfy ALL the following:
|
||||
|
||||
### 1. Screenshots at Every Step
|
||||
- Take a screenshot at EVERY significant test step — not just at the end
|
||||
- Every test scenario MUST have at least one BEFORE and one AFTER screenshot
|
||||
- Name screenshots sequentially: `{NN}-{action}-{state}.png` (e.g., `01-credits-before.png`, `02-credits-after.png`)
|
||||
- If a screenshot is missing for a scenario, the test is INCOMPLETE — go back and take it
|
||||
|
||||
### 2. Screenshots MUST Be Posted to PR
|
||||
- Push ALL screenshots to a temp branch `test-screenshots/pr-{N}`
|
||||
- Post a PR comment with ALL screenshots embedded inline using GitHub raw URLs
|
||||
- This is NOT optional — every test run MUST end with a PR comment containing screenshots
|
||||
- If screenshot upload fails, retry. If it still fails, list failed files and require manual drag-and-drop/paste attachment in the PR comment
|
||||
|
||||
### 3. State Verification with Before/After Evidence
|
||||
- For EVERY state-changing operation (API call, user action), capture the state BEFORE and AFTER
|
||||
- Log the actual API response values (e.g., `credits_before=100, credits_after=95`)
|
||||
- Screenshot MUST show the relevant UI state change
|
||||
- Compare expected vs actual values explicitly — do not just eyeball it
|
||||
|
||||
### 4. Negative Test Cases Are Mandatory
|
||||
- Test at least ONE negative case per feature (e.g., insufficient credits, invalid input, unauthorized access)
|
||||
- Verify error messages are user-friendly and accurate
|
||||
- Verify the system state did NOT change after a rejected operation
|
||||
|
||||
### 5. Test Report Must Include Full Evidence
|
||||
Each test scenario in the report MUST have:
|
||||
- **Steps**: What was done (exact commands or UI actions)
|
||||
- **Expected**: What should happen
|
||||
- **Actual**: What actually happened
|
||||
- **API Evidence**: Before/after API response values for state-changing operations
|
||||
- **Screenshot Evidence**: Before/after screenshots with explanations
|
||||
|
||||
## State Manipulation for Realistic Testing
|
||||
|
||||
When testing features that depend on specific states (rate limits, credits, quotas):
|
||||
|
||||
1. **Use Redis CLI to set counters directly:**
|
||||
```bash
|
||||
# Find the Redis container
|
||||
REDIS_CONTAINER=$(docker ps --format '{{.Names}}' | grep redis | head -1)
|
||||
# Set a key with expiry
|
||||
docker exec $REDIS_CONTAINER redis-cli SET key value EX ttl
|
||||
# Example: Set rate limit counter to near-limit
|
||||
docker exec $REDIS_CONTAINER redis-cli SET "rate_limit:user:test@test.com" 99 EX 3600
|
||||
# Example: Check current value
|
||||
docker exec $REDIS_CONTAINER redis-cli GET "rate_limit:user:test@test.com"
|
||||
```
|
||||
|
||||
2. **Use API calls to check before/after state:**
|
||||
```bash
|
||||
# BEFORE: Record current state
|
||||
BEFORE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
echo "Credits BEFORE: $BEFORE"
|
||||
|
||||
# Perform the action...
|
||||
|
||||
# AFTER: Record new state and compare
|
||||
AFTER=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
echo "Credits AFTER: $AFTER"
|
||||
echo "Delta: $(( BEFORE - AFTER ))"
|
||||
```
|
||||
|
||||
3. **Take screenshots BEFORE and AFTER state changes** — the UI must reflect the backend state change
|
||||
|
||||
4. **Never rely on mocked/injected browser state** — always use real backend state. Do NOT use `agent-browser eval` to fake UI state. The backend must be the source of truth.
|
||||
|
||||
5. **Use direct DB queries when needed:**
|
||||
```bash
|
||||
# Query via Supabase's PostgREST or docker exec into the DB
|
||||
docker exec supabase-db psql -U supabase_admin -d postgres -c "SELECT credits FROM user_credits WHERE user_id = '...';"
|
||||
```
|
||||
|
||||
6. **After every API test, verify the state change actually persisted:**
|
||||
```bash
|
||||
# Example: After a credits purchase, verify DB matches API
|
||||
API_CREDITS=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
DB_CREDITS=$(docker exec supabase-db psql -U supabase_admin -d postgres -t -c "SELECT credits FROM user_credits WHERE user_id = '...';" | tr -d ' ')
|
||||
[ "$API_CREDITS" = "$DB_CREDITS" ] && echo "CONSISTENT" || echo "MISMATCH: API=$API_CREDITS DB=$DB_CREDITS"
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
- `$ARGUMENTS` — worktree path (e.g. `$REPO_ROOT`) or PR number
|
||||
- If `--fix` flag is present, auto-fix bugs found and push fixes (like pr-address loop)
|
||||
|
||||
## Step 0: Resolve the target
|
||||
|
||||
```bash
|
||||
# If argument is a PR number, find its worktree
|
||||
gh pr view {N} --json headRefName --jq '.headRefName'
|
||||
# If argument is a path, use it directly
|
||||
```
|
||||
|
||||
Determine:
|
||||
- `REPO_ROOT` — the root repo directory: `git -C "$WORKTREE_PATH" worktree list | head -1 | awk '{print $1}'` (or `git rev-parse --show-toplevel` if not a worktree)
|
||||
- `WORKTREE_PATH` — the worktree directory
|
||||
- `PLATFORM_DIR` — `$WORKTREE_PATH/autogpt_platform`
|
||||
- `BACKEND_DIR` — `$PLATFORM_DIR/backend`
|
||||
- `FRONTEND_DIR` — `$PLATFORM_DIR/frontend`
|
||||
- `PR_NUMBER` — the PR number (from `gh pr list --head $(git branch --show-current)`)
|
||||
- `PR_TITLE` — the PR title, slugified (e.g. "Add copilot permissions" → "add-copilot-permissions")
|
||||
- `RESULTS_DIR` — `$REPO_ROOT/test-results/PR-{PR_NUMBER}-{slugified-title}`
|
||||
|
||||
Create the results directory:
|
||||
```bash
|
||||
PR_NUMBER=$(cd $WORKTREE_PATH && gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')
|
||||
PR_TITLE=$(cd $WORKTREE_PATH && gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT --json title --jq '.[0].title' | tr '[:upper:]' '[:lower:]' | sed 's/[^a-z0-9]/-/g' | sed 's/--*/-/g' | sed 's/^-//;s/-$//' | head -c 50)
|
||||
RESULTS_DIR="$REPO_ROOT/test-results/PR-${PR_NUMBER}-${PR_TITLE}"
|
||||
mkdir -p $RESULTS_DIR
|
||||
```
|
||||
|
||||
**Test user credentials** (for logging into the UI or verifying results manually):
|
||||
- Email: `test@test.com`
|
||||
- Password: `testtest123`
|
||||
|
||||
## Step 1: Understand the PR
|
||||
|
||||
Before testing, understand what changed:
|
||||
|
||||
```bash
|
||||
cd $WORKTREE_PATH
|
||||
|
||||
# Read PR description to understand the WHY
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
|
||||
git log --oneline dev..HEAD | head -20
|
||||
git diff dev --stat
|
||||
```
|
||||
|
||||
Read the PR description (Why / What / How) and changed files to understand:
|
||||
0. **Why** does this PR exist? What problem does it solve?
|
||||
1. **What** feature/fix does this PR implement?
|
||||
2. **How** does it work? What's the approach?
|
||||
3. What components are affected? (backend, frontend, copilot, executor, etc.)
|
||||
4. What are the key user-facing behaviors to test?
|
||||
|
||||
## Step 2: Write test scenarios
|
||||
|
||||
Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
|
||||
|
||||
```markdown
|
||||
# Test Plan: PR #{N} — {title}
|
||||
|
||||
## Scenarios
|
||||
1. [Scenario name] — [what to verify]
|
||||
2. ...
|
||||
|
||||
## API Tests (if applicable)
|
||||
1. [Endpoint] — [expected behavior]
|
||||
- Before state: [what to check before]
|
||||
- After state: [what to verify changed]
|
||||
|
||||
## UI Tests (if applicable)
|
||||
1. [Page/component] — [interaction to test]
|
||||
- Screenshot before: [what to capture]
|
||||
- Screenshot after: [what to capture]
|
||||
|
||||
## Negative Tests (REQUIRED — at least one per feature)
|
||||
1. [What should NOT happen] — [how to trigger it]
|
||||
- Expected error: [what error message/code]
|
||||
- State unchanged: [what to verify did NOT change]
|
||||
```
|
||||
|
||||
**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.
|
||||
|
||||
## Step 3: Environment setup
|
||||
|
||||
### 3a. Copy .env files from the root worktree
|
||||
|
||||
The root worktree (`$REPO_ROOT`) has the canonical `.env` files with all API keys. Copy them to the target worktree:
|
||||
|
||||
```bash
|
||||
# CRITICAL: .env files are NOT checked into git. They must be copied manually.
|
||||
cp $REPO_ROOT/autogpt_platform/.env $PLATFORM_DIR/.env
|
||||
cp $REPO_ROOT/autogpt_platform/backend/.env $BACKEND_DIR/.env
|
||||
cp $REPO_ROOT/autogpt_platform/frontend/.env $FRONTEND_DIR/.env
|
||||
```
|
||||
|
||||
### 3b. Configure copilot authentication
|
||||
|
||||
The copilot needs an LLM API to function. Two approaches (try subscription first):
|
||||
|
||||
#### Option 1: Subscription mode (preferred — uses your Claude Max/Pro subscription)
|
||||
|
||||
The `claude_agent_sdk` Python package **bundles its own Claude CLI binary** — no need to install `@anthropic-ai/claude-code` via npm. The backend auto-provisions credentials from environment variables on startup.
|
||||
|
||||
Run the helper script to extract tokens from your host and auto-update `backend/.env` (works on macOS, Linux, and Windows/WSL):
|
||||
|
||||
```bash
|
||||
# Extracts OAuth tokens and writes CLAUDE_CODE_OAUTH_TOKEN + CLAUDE_CODE_REFRESH_TOKEN into .env
|
||||
bash $BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env
|
||||
```
|
||||
|
||||
**How it works:** The script reads the OAuth token from:
|
||||
- **macOS**: system keychain (`"Claude Code-credentials"`)
|
||||
- **Linux/WSL**: `~/.claude/.credentials.json`
|
||||
- **Windows**: `%APPDATA%/claude/.credentials.json`
|
||||
|
||||
It sets `CLAUDE_CODE_OAUTH_TOKEN`, `CLAUDE_CODE_REFRESH_TOKEN`, and `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` in the `.env` file. On container startup, the backend auto-provisions `~/.claude/.credentials.json` inside the container from these env vars. The SDK's bundled CLI then authenticates using that file. No `claude login`, no npm install needed.
|
||||
|
||||
**Note:** The OAuth token expires (~24h). If copilot returns auth errors, re-run the script and restart: `$BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env && docker compose up -d copilot_executor`
|
||||
|
||||
#### Option 2: OpenRouter API key mode (fallback)
|
||||
|
||||
If subscription mode doesn't work, switch to API key mode using OpenRouter:
|
||||
|
||||
```bash
|
||||
# In $BACKEND_DIR/.env, ensure these are set:
|
||||
CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false
|
||||
CHAT_API_KEY=<value of OPEN_ROUTER_API_KEY from the same .env>
|
||||
CHAT_BASE_URL=https://openrouter.ai/api/v1
|
||||
CHAT_USE_CLAUDE_AGENT_SDK=true
|
||||
```
|
||||
|
||||
Use `sed` to update these values:
|
||||
```bash
|
||||
ORKEY=$(grep "^OPEN_ROUTER_API_KEY=" $BACKEND_DIR/.env | cut -d= -f2)
|
||||
[ -n "$ORKEY" ] || { echo "ERROR: OPEN_ROUTER_API_KEY is missing in $BACKEND_DIR/.env"; exit 1; }
|
||||
perl -i -pe 's/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false/' $BACKEND_DIR/.env
|
||||
# Add or update CHAT_API_KEY and CHAT_BASE_URL
|
||||
grep -q "^CHAT_API_KEY=" $BACKEND_DIR/.env && perl -i -pe "s|^CHAT_API_KEY=.*|CHAT_API_KEY=$ORKEY|" $BACKEND_DIR/.env || echo "CHAT_API_KEY=$ORKEY" >> $BACKEND_DIR/.env
|
||||
grep -q "^CHAT_BASE_URL=" $BACKEND_DIR/.env && perl -i -pe 's|^CHAT_BASE_URL=.*|CHAT_BASE_URL=https://openrouter.ai/api/v1|' $BACKEND_DIR/.env || echo "CHAT_BASE_URL=https://openrouter.ai/api/v1" >> $BACKEND_DIR/.env
|
||||
```
|
||||
|
||||
### 3c. Stop conflicting containers
|
||||
|
||||
```bash
|
||||
# Stop any running app containers (keep infra: supabase, redis, rabbitmq, clamav)
|
||||
docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocket|database_manager|scheduler|notification|frontend|migrate" | while read name; do
|
||||
docker stop "$name" 2>/dev/null
|
||||
done
|
||||
```
|
||||
|
||||
### 3e. Build and start
|
||||
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
|
||||
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker build failed"; exit 1; fi
|
||||
|
||||
cd $PLATFORM_DIR && docker compose up -d 2>&1 | tail -20
|
||||
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker compose up failed"; exit 1; fi
|
||||
```
|
||||
|
||||
**Note:** If the container appears to be running old code (e.g. missing PR changes), use `docker compose build --no-cache` to force a full rebuild. Docker BuildKit may sometimes reuse cached `COPY` layers from a previous build on a different branch.
|
||||
|
||||
**Expected time: 3-8 minutes** for build, 5-10 minutes with `--no-cache`.
|
||||
|
||||
### 3f. Wait for services to be ready
|
||||
|
||||
```bash
|
||||
# Poll until backend and frontend respond
|
||||
for i in $(seq 1 60); do
|
||||
BACKEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8006/docs 2>/dev/null)
|
||||
FRONTEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:3000 2>/dev/null)
|
||||
if [ "$BACKEND" = "200" ] && [ "$FRONTEND" = "200" ]; then
|
||||
echo "Services ready"
|
||||
break
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
```
|
||||
|
||||
|
||||
### 3h. Create test user and get auth token
|
||||
|
||||
```bash
|
||||
ANON_KEY=$(grep "NEXT_PUBLIC_SUPABASE_ANON_KEY=" $FRONTEND_DIR/.env | sed 's/.*NEXT_PUBLIC_SUPABASE_ANON_KEY=//' | tr -d '[:space:]')
|
||||
|
||||
# Signup (idempotent — returns "User already registered" if exists)
|
||||
RESULT=$(curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
|
||||
-H "apikey: $ANON_KEY" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"email":"test@test.com","password":"testtest123"}')
|
||||
|
||||
# If "Database error finding user", restart supabase-auth and retry
|
||||
if echo "$RESULT" | grep -q "Database error"; then
|
||||
docker restart supabase-auth && sleep 5
|
||||
curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
|
||||
-H "apikey: $ANON_KEY" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"email":"test@test.com","password":"testtest123"}'
|
||||
fi
|
||||
|
||||
# Get auth token
|
||||
TOKEN=$(curl -s -X POST 'http://localhost:8000/auth/v1/token?grant_type=password' \
|
||||
-H "apikey: $ANON_KEY" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"email":"test@test.com","password":"testtest123"}' | jq -r '.access_token // ""')
|
||||
```
|
||||
|
||||
**Use this token for ALL API calls:**
|
||||
```bash
|
||||
curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
|
||||
```
|
||||
|
||||
## Step 4: Run tests
|
||||
|
||||
### Service ports reference
|
||||
|
||||
| Service | Port | URL |
|
||||
|---------|------|-----|
|
||||
| Frontend | 3000 | http://localhost:3000 |
|
||||
| Backend REST | 8006 | http://localhost:8006 |
|
||||
| Supabase Auth (via Kong) | 8000 | http://localhost:8000 |
|
||||
| Executor | 8002 | http://localhost:8002 |
|
||||
| Copilot Executor | 8008 | http://localhost:8008 |
|
||||
| WebSocket | 8001 | http://localhost:8001 |
|
||||
| Database Manager | 8005 | http://localhost:8005 |
|
||||
| Redis | 6379 | localhost:6379 |
|
||||
| RabbitMQ | 5672 | localhost:5672 |
|
||||
|
||||
### API testing
|
||||
|
||||
Use `curl` with the auth token for backend API tests. **For EVERY API call that changes state, record before/after values:**
|
||||
|
||||
```bash
|
||||
# Example: List agents
|
||||
curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/graphs | jq . | head -20
|
||||
|
||||
# Example: Create an agent
|
||||
curl -s -X POST http://localhost:8006/api/graphs \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{...}' | jq .
|
||||
|
||||
# Example: Run an agent
|
||||
curl -s -X POST "http://localhost:8006/api/graphs/{graph_id}/execute" \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"data": {...}}'
|
||||
|
||||
# Example: Get execution results
|
||||
curl -s -H "Authorization: Bearer $TOKEN" \
|
||||
"http://localhost:8006/api/graphs/{graph_id}/executions/{exec_id}" | jq .
|
||||
```
|
||||
|
||||
**State verification pattern (use for EVERY state-changing API call):**
|
||||
```bash
|
||||
# 1. Record BEFORE state
|
||||
BEFORE_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
|
||||
echo "BEFORE: $BEFORE_STATE"
|
||||
|
||||
# 2. Perform the action
|
||||
ACTION_RESULT=$(curl -s -X POST ... | jq .)
|
||||
echo "ACTION RESULT: $ACTION_RESULT"
|
||||
|
||||
# 3. Record AFTER state
|
||||
AFTER_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
|
||||
echo "AFTER: $AFTER_STATE"
|
||||
|
||||
# 4. Log the comparison
|
||||
echo "=== STATE CHANGE VERIFICATION ==="
|
||||
echo "Before: $BEFORE_STATE"
|
||||
echo "After: $AFTER_STATE"
|
||||
echo "Expected change: {describe what should have changed}"
|
||||
```
|
||||
|
||||
### Browser testing with agent-browser
|
||||
|
||||
```bash
|
||||
# Close any existing session
|
||||
agent-browser close 2>/dev/null || true
|
||||
|
||||
# Use --session-name to persist cookies across navigations
|
||||
# This means login only needs to happen once per test session
|
||||
agent-browser --session-name pr-test open 'http://localhost:3000/login' --timeout 15000
|
||||
|
||||
# Get interactive elements
|
||||
agent-browser --session-name pr-test snapshot | grep "textbox\|button"
|
||||
|
||||
# Login
|
||||
agent-browser --session-name pr-test fill {email_ref} "test@test.com"
|
||||
agent-browser --session-name pr-test fill {password_ref} "testtest123"
|
||||
agent-browser --session-name pr-test click {login_button_ref}
|
||||
sleep 5
|
||||
|
||||
# Dismiss cookie banner if present
|
||||
agent-browser --session-name pr-test click 'text=Accept All' 2>/dev/null || true
|
||||
|
||||
# Navigate — cookies are preserved so login persists
|
||||
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
|
||||
|
||||
# Take screenshot
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/01-page.png
|
||||
|
||||
# Interact with elements
|
||||
agent-browser --session-name pr-test fill {ref} "text"
|
||||
agent-browser --session-name pr-test press "Enter"
|
||||
agent-browser --session-name pr-test click {ref}
|
||||
agent-browser --session-name pr-test click 'text=Button Text'
|
||||
|
||||
# Read page content
|
||||
agent-browser --session-name pr-test snapshot | grep "text:"
|
||||
```
|
||||
|
||||
**Key pages:**
|
||||
- `/copilot` — CoPilot chat (for testing copilot features)
|
||||
- `/build` — Agent builder (for testing block/node features)
|
||||
- `/build?flowID={id}` — Specific agent in builder
|
||||
- `/library` — Agent library (for testing listing/import features)
|
||||
- `/library/agents/{id}` — Agent detail with run history
|
||||
- `/marketplace` — Marketplace
|
||||
|
||||
### Checking logs
|
||||
|
||||
```bash
|
||||
# Backend REST server
|
||||
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30
|
||||
|
||||
# Executor (runs agent graphs)
|
||||
docker logs autogpt_platform-executor-1 2>&1 | tail -30
|
||||
|
||||
# Copilot executor (runs copilot chat sessions)
|
||||
docker logs autogpt_platform-copilot_executor-1 2>&1 | tail -30
|
||||
|
||||
# Frontend
|
||||
docker logs autogpt_platform-frontend-1 2>&1 | tail -30
|
||||
|
||||
# Filter for errors
|
||||
docker logs autogpt_platform-executor-1 2>&1 | grep -i "error\|exception\|traceback" | tail -20
|
||||
```
|
||||
|
||||
### Copilot chat testing
|
||||
|
||||
The copilot uses SSE streaming. To test via API:
|
||||
|
||||
```bash
|
||||
# Create a session
|
||||
SESSION_ID=$(curl -s -X POST 'http://localhost:8006/api/chat/sessions' \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{}' | jq -r '.id // .session_id // ""')
|
||||
|
||||
# Stream a message (SSE - will stream chunks)
|
||||
curl -N -X POST "http://localhost:8006/api/chat/sessions/$SESSION_ID/stream" \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"message": "Hello, what can you help me with?"}' \
|
||||
--max-time 60 2>/dev/null | head -50
|
||||
```
|
||||
|
||||
Or test via browser (preferred for UI verification):
|
||||
```bash
|
||||
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
|
||||
# ... fill chat input and press Enter, wait 20-30s for response
|
||||
```
|
||||
|
||||
## Step 5: Record results and take screenshots
|
||||
|
||||
**Take a screenshot at EVERY significant test step** — before and after interactions, on success, and on failure. This is NON-NEGOTIABLE.
|
||||
|
||||
**Required screenshot pattern for each test scenario:**
|
||||
```bash
|
||||
# BEFORE the action
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-before.png
|
||||
|
||||
# Perform the action...
|
||||
|
||||
# AFTER the action
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-after.png
|
||||
```
|
||||
|
||||
**Naming convention:**
|
||||
```bash
|
||||
# Examples:
|
||||
# $RESULTS_DIR/01-login-page-before.png
|
||||
# $RESULTS_DIR/02-login-page-after.png
|
||||
# $RESULTS_DIR/03-credits-page-before.png
|
||||
# $RESULTS_DIR/04-credits-purchase-after.png
|
||||
# $RESULTS_DIR/05-negative-insufficient-credits.png
|
||||
# $RESULTS_DIR/06-error-state.png
|
||||
```
|
||||
|
||||
**Minimum requirements:**
|
||||
- At least TWO screenshots per test scenario (before + after)
|
||||
- At least ONE screenshot for each negative test case showing the error state
|
||||
- If a test fails, screenshot the failure state AND any error logs visible in the UI
|
||||
|
||||
## Step 6: Show results to user with screenshots
|
||||
|
||||
**CRITICAL: After all tests complete, you MUST show every screenshot to the user using the Read tool, with an explanation of what each screenshot shows.** This is the most important part of the test report — the user needs to visually verify the results.
|
||||
|
||||
For each screenshot:
|
||||
1. Use the `Read` tool to display the PNG file (Claude can read images)
|
||||
2. Write a 1-2 sentence explanation below it describing:
|
||||
- What page/state is being shown
|
||||
- What the screenshot proves (which test scenario it validates)
|
||||
- Any notable details visible in the UI
|
||||
|
||||
Format the output like this:
|
||||
|
||||
```markdown
|
||||
### Screenshot 1: {descriptive title}
|
||||
[Read the PNG file here]
|
||||
|
||||
**What it shows:** {1-2 sentence explanation of what this screenshot proves}
|
||||
|
||||
---
|
||||
```
|
||||
|
||||
After showing all screenshots, output a **detailed** summary table:
|
||||
|
||||
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|
||||
|---|----------|--------|-------------|-------------------|
|
||||
| 1 | {name} | PASS/FAIL | Before: X, After: Y | 01-before.png, 02-after.png |
|
||||
| 2 | ... | ... | ... | ... |
|
||||
|
||||
**IMPORTANT:** As you show each screenshot and record test results, persist them in shell variables for Step 7:
|
||||
|
||||
```bash
|
||||
# Build these variables during Step 6 — they are required by Step 7's script
|
||||
# NOTE: declare -A requires Bash 4.0+. This is standard on modern systems (macOS ships zsh
|
||||
# but Homebrew bash is 5.x; Linux typically has bash 5.x). If running on Bash <4, use a
|
||||
# plain variable with a lookup function instead.
|
||||
declare -A SCREENSHOT_EXPLANATIONS=(
|
||||
["01-login-page.png"]="Shows the login page loaded successfully with SSO options visible."
|
||||
["02-builder-with-block.png"]="The builder canvas displays the newly added block connected to the trigger."
|
||||
# ... one entry per screenshot, using the same explanations you showed the user above
|
||||
)
|
||||
|
||||
TEST_RESULTS_TABLE="| 1 | Login flow | PASS | N/A | 01-login-before.png, 02-login-after.png |
|
||||
| 2 | Credits purchase | PASS | Before: 100, After: 95 | 03-credits-before.png, 04-credits-after.png |
|
||||
| 3 | Insufficient credits (negative) | PASS | Credits: 0, rejected | 05-insufficient-credits-error.png |"
|
||||
# ... one row per test scenario with actual results
|
||||
```
|
||||
|
||||
## Step 7: Post test report as PR comment with screenshots
|
||||
|
||||
Upload screenshots to the PR using the GitHub Git API (no local git operations — safe for worktrees), then post a comment with inline images and per-screenshot explanations.
|
||||
|
||||
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
|
||||
|
||||
```bash
|
||||
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
|
||||
REPO="Significant-Gravitas/AutoGPT"
|
||||
SCREENSHOTS_BRANCH="test-screenshots/pr-${PR_NUMBER}"
|
||||
SCREENSHOTS_DIR="test-screenshots/PR-${PR_NUMBER}"
|
||||
|
||||
# Step 1: Create blobs for each screenshot and build tree JSON
|
||||
# Retry each blob upload up to 3 times. If still failing, list them at end of report.
|
||||
shopt -s nullglob
|
||||
SCREENSHOT_FILES=("$RESULTS_DIR"/*.png)
|
||||
if [ ${#SCREENSHOT_FILES[@]} -eq 0 ]; then
|
||||
echo "ERROR: No screenshots found in $RESULTS_DIR. Test run is incomplete."
|
||||
exit 1
|
||||
fi
|
||||
TREE_JSON='['
|
||||
FIRST=true
|
||||
FAILED_UPLOADS=()
|
||||
for img in "${SCREENSHOT_FILES[@]}"; do
|
||||
BASENAME=$(basename "$img")
|
||||
B64=$(base64 < "$img")
|
||||
BLOB_SHA=""
|
||||
for attempt in 1 2 3; do
|
||||
BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha' 2>/dev/null || true)
|
||||
[ -n "$BLOB_SHA" ] && break
|
||||
sleep 1
|
||||
done
|
||||
if [ -z "$BLOB_SHA" ]; then
|
||||
FAILED_UPLOADS+=("$img")
|
||||
continue
|
||||
fi
|
||||
if [ "$FIRST" = true ]; then FIRST=false; else TREE_JSON+=','; fi
|
||||
TREE_JSON+="{\"path\":\"${SCREENSHOTS_DIR}/${BASENAME}\",\"mode\":\"100644\",\"type\":\"blob\",\"sha\":\"${BLOB_SHA}\"}"
|
||||
done
|
||||
TREE_JSON+=']'
|
||||
|
||||
# Step 2: Create tree, commit, and branch ref
|
||||
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
--jq '.sha')
|
||||
gh api "repos/${REPO}/git/refs" \
|
||||
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-f sha="$COMMIT_SHA" 2>/dev/null \
|
||||
|| gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-X PATCH -f sha="$COMMIT_SHA" -f force=true
|
||||
```
|
||||
|
||||
Then post the comment with **inline images AND explanations for each screenshot**:
|
||||
|
||||
```bash
|
||||
REPO_URL="https://raw.githubusercontent.com/${REPO}/${SCREENSHOTS_BRANCH}"
|
||||
|
||||
# Build image markdown using uploaded image URLs; skip FAILED_UPLOADS (listed separately)
|
||||
|
||||
IMAGE_MARKDOWN=""
|
||||
for img in "${SCREENSHOT_FILES[@]}"; do
|
||||
BASENAME=$(basename "$img")
|
||||
TITLE=$(echo "${BASENAME%.png}" | sed 's/^[0-9]*-//' | sed 's/-/ /g' | awk '{for(i=1;i<=NF;i++) $i=toupper(substr($i,1,1)) tolower(substr($i,2))}1')
|
||||
# Skip images that failed to upload — they will be listed at the end
|
||||
IS_FAILED=false
|
||||
for failed in "${FAILED_UPLOADS[@]}"; do
|
||||
[ "$(basename "$failed")" = "$BASENAME" ] && IS_FAILED=true && break
|
||||
done
|
||||
if [ "$IS_FAILED" = true ]; then
|
||||
continue
|
||||
fi
|
||||
EXPLANATION="${SCREENSHOT_EXPLANATIONS[$BASENAME]}"
|
||||
if [ -z "$EXPLANATION" ]; then
|
||||
echo "ERROR: Missing screenshot explanation for $BASENAME. Add it to SCREENSHOT_EXPLANATIONS in Step 6."
|
||||
exit 1
|
||||
fi
|
||||
IMAGE_MARKDOWN="${IMAGE_MARKDOWN}
|
||||
### ${TITLE}
|
||||

|
||||
${EXPLANATION}
|
||||
"
|
||||
done
|
||||
|
||||
# Write comment body to file to avoid shell interpretation issues with special characters
|
||||
COMMENT_FILE=$(mktemp)
|
||||
# If any uploads failed, append a section listing them with instructions
|
||||
FAILED_SECTION=""
|
||||
if [ ${#FAILED_UPLOADS[@]} -gt 0 ]; then
|
||||
FAILED_SECTION="
|
||||
## ⚠️ Failed Screenshot Uploads
|
||||
The following screenshots could not be uploaded via the GitHub API after 3 retries.
|
||||
**To add them:** drag-and-drop or paste these files into a PR comment manually:
|
||||
"
|
||||
for failed in "${FAILED_UPLOADS[@]}"; do
|
||||
FAILED_SECTION="${FAILED_SECTION}
|
||||
- \`$(basename "$failed")\` (local path: \`$failed\`)"
|
||||
done
|
||||
FAILED_SECTION="${FAILED_SECTION}
|
||||
|
||||
**Run status:** INCOMPLETE until the files above are manually attached and visible inline in the PR."
|
||||
fi
|
||||
|
||||
cat > "$COMMENT_FILE" <<INNEREOF
|
||||
## E2E Test Report
|
||||
|
||||
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|
||||
|---|----------|--------|-------------|-------------------|
|
||||
${TEST_RESULTS_TABLE}
|
||||
|
||||
${IMAGE_MARKDOWN}
|
||||
${FAILED_SECTION}
|
||||
INNEREOF
|
||||
|
||||
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
|
||||
rm -f "$COMMENT_FILE"
|
||||
```
|
||||
|
||||
**The PR comment MUST include:**
|
||||
1. A summary table of all scenarios with PASS/FAIL and before/after API evidence
|
||||
2. Every successfully uploaded screenshot rendered inline; any failed uploads listed with manual attachment instructions
|
||||
3. A 1-2 sentence explanation below each screenshot describing what it proves
|
||||
|
||||
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
|
||||
|
||||
## Fix mode (--fix flag)
|
||||
|
||||
When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.
|
||||
|
||||
### Fix protocol for EVERY issue found (including UX issues):
|
||||
|
||||
1. **Identify** the root cause in the code — read the relevant source files
|
||||
2. **Write a failing test first** (TDD): For backend bugs, write a test marked with `pytest.mark.xfail(reason="...")`. For frontend/Playwright bugs, write a test with `.fixme` annotation. Run it to confirm it fails as expected.
|
||||
3. **Screenshot** the broken state: `agent-browser screenshot $RESULTS_DIR/{NN}-broken-{description}.png`
|
||||
4. **Fix** the code in the worktree
|
||||
5. **Rebuild** ONLY the affected service (not the whole stack):
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose up --build -d {service_name}
|
||||
# e.g., docker compose up --build -d rest_server
|
||||
# e.g., docker compose up --build -d frontend
|
||||
```
|
||||
6. **Wait** for the service to be ready (poll health endpoint)
|
||||
7. **Re-test** the same scenario
|
||||
8. **Screenshot** the fixed state: `agent-browser screenshot $RESULTS_DIR/{NN}-fixed-{description}.png`
|
||||
9. **Remove the xfail/fixme marker** from the test written in step 2, and verify it passes
|
||||
10. **Verify** the fix did not break other scenarios (run a quick smoke test)
|
||||
11. **Commit and push** immediately:
|
||||
```bash
|
||||
cd $WORKTREE_PATH
|
||||
git add -A
|
||||
git commit -m "fix: {description of fix}"
|
||||
git push
|
||||
```
|
||||
12. **Continue** to the next test scenario
|
||||
|
||||
### Fix loop (like pr-address)
|
||||
|
||||
```text
|
||||
test scenario → find issue (bug OR UX problem) → screenshot broken state
|
||||
→ fix code → rebuild affected service only → re-test → screenshot fixed state
|
||||
→ verify no regressions → commit + push
|
||||
→ repeat for next scenario
|
||||
→ after ALL scenarios pass, run full re-test to verify everything together
|
||||
```
|
||||
|
||||
**Key differences from non-fix mode:**
|
||||
- UX issues count as bugs — fix them (bad alignment, confusing labels, missing loading states)
|
||||
- Every fix MUST have a before/after screenshot pair proving it works
|
||||
- Commit after EACH fix, not in a batch at the end
|
||||
- The final re-test must produce a clean set of all-passing screenshots
|
||||
|
||||
## Known issues and workarounds
|
||||
|
||||
### Problem: "Database error finding user" on signup
|
||||
**Cause:** Supabase auth service schema cache is stale after migration.
|
||||
**Fix:** `docker restart supabase-auth && sleep 5` then retry signup.
|
||||
|
||||
### Problem: Copilot returns auth errors in subscription mode
|
||||
**Cause:** `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` but `CLAUDE_CODE_OAUTH_TOKEN` is not set or expired.
|
||||
**Fix:** Re-extract the OAuth token from macOS keychain (see step 3b, Option 1) and recreate the container (`docker compose up -d copilot_executor`). The backend auto-provisions `~/.claude/.credentials.json` from the env var on startup. No `npm install` or `claude login` needed — the SDK bundles its own CLI binary.
|
||||
|
||||
### Problem: agent-browser can't find chromium
|
||||
**Cause:** The Dockerfile auto-provisions system chromium on all architectures (including ARM64). If your branch is behind `dev`, this may not be present yet.
|
||||
**Fix:** Check if chromium exists: `which chromium || which chromium-browser`. If missing, install it: `apt-get install -y chromium` and set `AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium` in the container environment.
|
||||
|
||||
### Problem: agent-browser selector matches multiple elements
|
||||
**Cause:** `text=X` matches all elements containing that text.
|
||||
**Fix:** Use `agent-browser snapshot` to get specific `ref=eNN` references, then use those: `agent-browser click eNN`.
|
||||
|
||||
### Problem: Frontend shows cookie banner blocking interaction
|
||||
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.
|
||||
|
||||
### Problem: Container loses npm packages after rebuild
|
||||
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
|
||||
**Fix:** Add packages to the Dockerfile instead of installing at runtime.
|
||||
|
||||
### Problem: Services not starting after `docker compose up`
|
||||
**Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.
|
||||
|
||||
### Problem: Docker uses cached layers with old code (PR changes not visible)
|
||||
**Cause:** `docker compose up --build` reuses cached `COPY` layers from previous builds. If the PR branch changes Python files but the previous build already cached that layer from `dev`, the container runs `dev` code.
|
||||
**Fix:** Always use `docker compose build --no-cache` for the first build of a PR branch. Subsequent rebuilds within the same branch can use `--build`.
|
||||
|
||||
### Problem: `agent-browser open` loses login session
|
||||
**Cause:** Without session persistence, `agent-browser open` starts fresh.
|
||||
**Fix:** Use `--session-name pr-test` on ALL agent-browser commands. This auto-saves/restores cookies and localStorage across navigations. Alternatively, use `agent-browser eval "window.location.href = '...'"` to navigate within the same context.
|
||||
|
||||
### Problem: Supabase auth returns "Database error querying schema"
|
||||
**Cause:** The database schema changed (migration ran) but supabase-auth has a stale schema cache.
|
||||
**Fix:** `docker restart supabase-db && sleep 10 && docker restart supabase-auth && sleep 8`. If user data was lost, re-signup.
|
||||
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,8 +1,12 @@
|
||||
<!-- Clearly explain the need for these changes: -->
|
||||
### Why / What / How
|
||||
|
||||
<!-- Why: Why does this PR exist? What problem does it solve, or what's broken/missing without it? -->
|
||||
<!-- What: What does this PR change? Summarize the changes at a high level. -->
|
||||
<!-- How: How does it work? Describe the approach, key implementation details, or architecture decisions. -->
|
||||
|
||||
### Changes 🏗️
|
||||
|
||||
<!-- Concisely describe all of the changes made in this pull request: -->
|
||||
<!-- List the key changes. Keep it higher level than the diff but specific enough to highlight what's new/modified. -->
|
||||
|
||||
### Checklist 📋
|
||||
|
||||
|
||||
114
.github/workflows/platform-backend-ci.yml
vendored
114
.github/workflows/platform-backend-ci.yml
vendored
@@ -27,10 +27,91 @@ defaults:
|
||||
working-directory: autogpt_platform/backend
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 10
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-py3.12-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run Linters
|
||||
run: poetry run lint --skip-pyright
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
type-check:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- name: Run Pyright
|
||||
run: poetry run pyright --pythonversion ${{ matrix.python-version }}
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 15
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -98,9 +179,9 @@ jobs:
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
@@ -158,22 +239,22 @@ jobs:
|
||||
echo "Waiting for ClamAV daemon to start..."
|
||||
max_attempts=60
|
||||
attempt=0
|
||||
|
||||
|
||||
until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
|
||||
echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
|
||||
sleep 5
|
||||
attempt=$((attempt+1))
|
||||
done
|
||||
|
||||
|
||||
if [ $attempt -eq $max_attempts ]; then
|
||||
echo "ClamAV failed to start after $((max_attempts*5)) seconds"
|
||||
echo "Checking ClamAV service logs..."
|
||||
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "ClamAV is ready!"
|
||||
|
||||
|
||||
# Verify ClamAV is responsive
|
||||
echo "Testing ClamAV connection..."
|
||||
timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
|
||||
@@ -188,18 +269,13 @@ jobs:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- id: lint
|
||||
name: Run Linter
|
||||
run: poetry run lint
|
||||
|
||||
- name: Run pytest with coverage
|
||||
- name: Run pytest
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
else
|
||||
poetry run pytest -s -vv
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
@@ -211,6 +287,12 @@ jobs:
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
@@ -224,9 +306,3 @@ jobs:
|
||||
# the backend service, docker composes, and examples
|
||||
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
|
||||
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
|
||||
4
.github/workflows/platform-fullstack-ci.yml
vendored
4
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -294,7 +294,7 @@ jobs:
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
path: autogpt_platform/frontend/playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
@@ -303,7 +303,7 @@ jobs:
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
path: autogpt_platform/frontend/test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
|
||||
@@ -53,18 +53,40 @@ AutoGPT Platform is a monorepo containing:
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
|
||||
```bash
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
## Summary
|
||||
- use `backticks` freely here
|
||||
PREOF
|
||||
gh pr create --title "..." --body-file "$PR_BODY" --base dev
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, follow a test-first approach:
|
||||
|
||||
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
|
||||
2. **Implement the fix/feature** — write the minimal code to make the test pass.
|
||||
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
|
||||
|
||||
This ensures every change is covered by a test and that the test actually validates the intended behavior.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` — inline review comments
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
54
autogpt_platform/autogpt_libs/poetry.lock
generated
54
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
@@ -67,7 +67,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
|
||||
optional = false
|
||||
python-versions = "<3.11,>=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
||||
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
||||
@@ -541,7 +541,7 @@ description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
|
||||
{file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
|
||||
@@ -2181,14 +2181,14 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "7.0.0"
|
||||
version = "7.1.0"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"},
|
||||
{file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"},
|
||||
{file = "pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678"},
|
||||
{file = "pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2342,30 +2342,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.15.0"
|
||||
version = "0.15.7"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.15.0-py3-none-linux_armv6l.whl", hash = "sha256:aac4ebaa612a82b23d45964586f24ae9bc23ca101919f5590bdb368d74ad5455"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:dcd4be7cc75cfbbca24a98d04d0b9b36a270d0833241f776b788d59f4142b14d"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d747e3319b2bce179c7c1eaad3d884dc0a199b5f4d5187620530adf9105268ce"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:650bd9c56ae03102c51a5e4b554d74d825ff3abe4db22b90fd32d816c2e90621"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a6664b7eac559e3048223a2da77769c2f92b43a6dfd4720cef42654299a599c9"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f811f97b0f092b35320d1556f3353bf238763420ade5d9e62ebd2b73f2ff179"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:761ec0a66680fab6454236635a39abaf14198818c8cdf691e036f4bc0f406b2d"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:940f11c2604d317e797b289f4f9f3fa5555ffe4fb574b55ed006c3d9b6f0eb78"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbca3d40558789126da91d7ef9a7c87772ee107033db7191edefa34e2c7f1b4"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9a121a96db1d75fa3eb39c4539e607f628920dd72ff1f7c5ee4f1b768ac62d6e"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5298d518e493061f2eabd4abd067c7e4fb89e2f63291c94332e35631c07c3662"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:afb6e603d6375ff0d6b0cee563fa21ab570fd15e65c852cb24922cef25050cf1"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:77e515f6b15f828b94dc17d2b4ace334c9ddb7d9468c54b2f9ed2b9c1593ef16"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6f6e80850a01eb13b3e42ee0ebdf6e4497151b48c35051aab51c101266d187a3"},
|
||||
{file = "ruff-0.15.0-py3-none-win32.whl", hash = "sha256:238a717ef803e501b6d51e0bdd0d2c6e8513fe9eec14002445134d3907cd46c3"},
|
||||
{file = "ruff-0.15.0-py3-none-win_amd64.whl", hash = "sha256:dd5e4d3301dc01de614da3cdffc33d4b1b96fb89e45721f1598e5532ccf78b18"},
|
||||
{file = "ruff-0.15.0-py3-none-win_arm64.whl", hash = "sha256:c480d632cc0ca3f0727acac8b7d053542d9e114a462a145d0b00e7cd658c515a"},
|
||||
{file = "ruff-0.15.0.tar.gz", hash = "sha256:6bdea47cdbea30d40f8f8d7d69c0854ba7c15420ec75a26f463290949d7f7e9a"},
|
||||
{file = "ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580"},
|
||||
{file = "ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de"},
|
||||
{file = "ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1"},
|
||||
{file = "ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2"},
|
||||
{file = "ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2564,7 +2564,7 @@ description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
|
||||
@@ -2912,4 +2912,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"
|
||||
content-hash = "e0936a065565550afed18f6298b7e04e814b44100def7049f1a0d68662624a39"
|
||||
|
||||
@@ -26,8 +26,8 @@ pyright = "^1.1.408"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-cov = "^7.0.0"
|
||||
ruff = "^0.15.0"
|
||||
pytest-cov = "^7.1.0"
|
||||
ruff = "^0.15.7"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
|
||||
## ===== SIGNUP / INVITE GATE ===== ##
|
||||
# Set to true to require an invite before users can sign up
|
||||
ENABLE_INVITE_GATE=false
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
|
||||
@@ -61,12 +61,13 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **Lazy `%s` logging** — `logger.info("Processing %s items", count)` not `logger.info(f"Processing {count} items")`
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
@@ -75,6 +76,7 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
@@ -84,6 +86,30 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
@@ -50,7 +50,7 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# =============================== DB MIGRATOR =============================== #
|
||||
@@ -82,7 +82,7 @@ RUN pip3 install prisma>=0.15.0 --break-system-packages
|
||||
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
|
||||
COPY autogpt_platform/backend/migrations ./migrations
|
||||
|
||||
# ============================== BACKEND SERVER ============================== #
|
||||
@@ -121,19 +121,21 @@ RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1 \
|
||||
# Install agent-browser (Copilot browser tool) using the system chromium package.
|
||||
# Chrome for Testing (the binary agent-browser downloads via `agent-browser install`)
|
||||
# has no ARM64 builds, so we use the distro-packaged chromium instead — verified to
|
||||
# work with agent-browser via Docker tests on arm64; amd64 is validated in CI.
|
||||
# Note: system chromium tracks the Debian package schedule rather than a pinned
|
||||
# Chrome for Testing release. If agent-browser requires a specific Chrome version,
|
||||
# verify compatibility against the chromium package version in the base image.
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends chromium fonts-liberation \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& agent-browser install \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
ENV AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
|
||||
@@ -18,15 +18,20 @@ from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.integrations.models import get_all_provider_names
|
||||
from backend.api.features.integrations.router import (
|
||||
CredentialsMetaResponse,
|
||||
to_meta_response,
|
||||
)
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -91,18 +96,6 @@ class OAuthCompleteResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class CredentialSummary(BaseModel):
|
||||
"""Summary of a credential without sensitive data."""
|
||||
|
||||
id: str
|
||||
provider: str
|
||||
type: CredentialsType
|
||||
title: Optional[str] = None
|
||||
scopes: Optional[list[str]] = None
|
||||
username: Optional[str] = None
|
||||
host: Optional[str] = None
|
||||
|
||||
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Information about an integration provider."""
|
||||
|
||||
@@ -473,12 +466,12 @@ async def complete_oauth(
|
||||
)
|
||||
|
||||
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialsMetaResponse])
|
||||
async def list_credentials(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
"""
|
||||
List all credentials for the authenticated user.
|
||||
|
||||
@@ -486,28 +479,19 @@ async def list_credentials(
|
||||
"""
|
||||
credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@integrations_router.get(
|
||||
"/{provider}/credentials", response_model=list[CredentialSummary]
|
||||
"/{provider}/credentials", response_model=list[CredentialsMetaResponse]
|
||||
)
|
||||
async def list_credentials_by_provider(
|
||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
"""
|
||||
List credentials for a specific provider.
|
||||
"""
|
||||
@@ -515,16 +499,7 @@ async def list_credentials_by_provider(
|
||||
auth.user_id, provider
|
||||
)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -597,11 +572,11 @@ async def create_credential(
|
||||
# Store credentials
|
||||
try:
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store credentials: {e}")
|
||||
except Exception:
|
||||
logger.exception("Failed to store credentials")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
detail="Failed to store credentials",
|
||||
)
|
||||
|
||||
logger.info(f"Created {request.type} credentials for provider {provider}")
|
||||
@@ -639,15 +614,18 @@ async def delete_credential(
|
||||
use the main API's delete endpoint which handles webhook cleanup and
|
||||
token revocation.
|
||||
"""
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if creds.provider != provider:
|
||||
if not provider_matches(creds.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
|
||||
await creds_manager.delete(auth.user_id, cred_id)
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||
|
||||
|
||||
class UserHistoryResponse(BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class CreateInvitedUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class InvitedUserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
|
||||
return cls.model_validate(record.model_dump())
|
||||
|
||||
|
||||
class InvitedUsersResponse(BaseModel):
|
||||
invited_users: list[InvitedUserResponse]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class BulkInvitedUserRowResponse(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserResponse] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResponse(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResponse]
|
||||
|
||||
@classmethod
|
||||
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||
return cls(
|
||||
created_count=result.created_count,
|
||||
skipped_count=result.skipped_count,
|
||||
error_count=result.error_count,
|
||||
results=[
|
||||
BulkInvitedUserRowResponse(
|
||||
row_number=row.row_number,
|
||||
email=row.email,
|
||||
name=row.name,
|
||||
status=row.status,
|
||||
message=row.message,
|
||||
invited_user=(
|
||||
InvitedUserResponse.from_record(row.invited_user)
|
||||
if row.invited_user is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
@@ -7,6 +7,8 @@ import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
@@ -132,3 +134,40 @@ async def admin_download_agent_file(
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/submissions/{store_listing_version_id}/preview",
|
||||
summary="Admin Preview Submission Listing",
|
||||
)
|
||||
async def admin_preview_submission(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""
|
||||
Preview a marketplace submission as it would appear on the listing page.
|
||||
Bypasses the APPROVED-only StoreAgent view so admins can preview pending
|
||||
submissions before approving.
|
||||
"""
|
||||
return await store_db.get_store_agent_details_as_admin(store_listing_version_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/add-to-library",
|
||||
summary="Admin Add Pending Agent to Library",
|
||||
status_code=201,
|
||||
)
|
||||
async def admin_add_agent_to_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Add a pending marketplace agent to the admin's library for review.
|
||||
Uses admin-level access to bypass marketplace APPROVED-only checks.
|
||||
|
||||
The builder can load the graph because get_graph() checks library
|
||||
membership as a fallback: "you added it, you keep it."
|
||||
"""
|
||||
return await library_db.add_store_agent_to_library_as_admin(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
"""Tests for admin store routes and the bypass logic they depend on.
|
||||
|
||||
Tests are organized by what they protect:
|
||||
- SECRT-2162: get_graph_as_admin bypasses ownership/marketplace checks
|
||||
- SECRT-2167 security: admin endpoints reject non-admin users
|
||||
- SECRT-2167 bypass: preview queries StoreListingVersion (not StoreAgent view),
|
||||
and add-to-library uses get_graph_as_admin (not get_graph)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.graph import get_graph_as_admin
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .store_admin_routes import router as store_admin_router
|
||||
|
||||
# Shared constants
|
||||
ADMIN_USER_ID = "admin-user-id"
|
||||
CREATOR_USER_ID = "other-creator-id"
|
||||
GRAPH_ID = "test-graph-id"
|
||||
GRAPH_VERSION = 3
|
||||
SLV_ID = "test-store-listing-version-id"
|
||||
|
||||
|
||||
def _make_mock_graph(user_id: str = CREATOR_USER_ID) -> MagicMock:
|
||||
graph = MagicMock()
|
||||
graph.userId = user_id
|
||||
graph.id = GRAPH_ID
|
||||
graph.version = GRAPH_VERSION
|
||||
graph.Nodes = []
|
||||
return graph
|
||||
|
||||
|
||||
# ---- SECRT-2162: get_graph_as_admin bypasses ownership checks ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_access_pending_agent_not_owned() -> None:
|
||||
"""get_graph_as_admin must return a graph even when the admin doesn't own
|
||||
it and it's not APPROVED in the marketplace."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=False,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_download_pending_agent_with_subagents() -> None:
|
||||
"""get_graph_as_admin with for_export=True must call get_sub_graphs
|
||||
and pass sub_graphs to GraphModel.from_db."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_sub_graph = MagicMock(name="SubGraph")
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.get_sub_graphs",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_sub_graph],
|
||||
) as mock_get_sub,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=True,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_get_sub.assert_awaited_once_with(mock_graph)
|
||||
mock_from_db.assert_called_once_with(
|
||||
graph=mock_graph,
|
||||
sub_graphs=[mock_sub_graph],
|
||||
for_export=True,
|
||||
)
|
||||
|
||||
|
||||
# ---- SECRT-2167 security: admin endpoints reject non-admin users ---- #
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(store_admin_router)
|
||||
|
||||
|
||||
@app.exception_handler(NotFoundError)
|
||||
async def _not_found_handler(
|
||||
request: fastapi.Request, exc: NotFoundError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all route tests in this module."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_preview_requires_admin(mock_jwt_user) -> None:
|
||||
"""Non-admin users must get 403 on the preview endpoint."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_add_to_library_requires_admin(mock_jwt_user) -> None:
|
||||
"""Non-admin users must get 403 on the add-to-library endpoint."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
response = client.post(f"/admin/submissions/{SLV_ID}/add-to-library")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_preview_nonexistent_submission(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Preview of a nonexistent submission returns 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.store_admin_routes.store_db"
|
||||
".get_store_agent_details_as_admin",
|
||||
side_effect=NotFoundError("not found"),
|
||||
)
|
||||
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- SECRT-2167 bypass: verify the right data sources are used ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_queries_store_listing_version_not_store_agent() -> None:
|
||||
"""get_store_agent_details_as_admin must query StoreListingVersion
|
||||
directly (not the APPROVED-only StoreAgent view). This is THE test that
|
||||
prevents the bypass from being accidentally reverted."""
|
||||
from backend.api.features.store.db import get_store_agent_details_as_admin
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.id = SLV_ID
|
||||
mock_slv.name = "Test Agent"
|
||||
mock_slv.subHeading = "Short desc"
|
||||
mock_slv.description = "Long desc"
|
||||
mock_slv.videoUrl = None
|
||||
mock_slv.agentOutputDemoUrl = None
|
||||
mock_slv.imageUrls = ["https://example.com/img.png"]
|
||||
mock_slv.instructions = None
|
||||
mock_slv.categories = ["productivity"]
|
||||
mock_slv.version = 1
|
||||
mock_slv.agentGraphId = GRAPH_ID
|
||||
mock_slv.agentGraphVersion = GRAPH_VERSION
|
||||
mock_slv.updatedAt = datetime(2026, 3, 24, tzinfo=timezone.utc)
|
||||
mock_slv.recommendedScheduleCron = "0 9 * * *"
|
||||
|
||||
mock_listing = MagicMock()
|
||||
mock_listing.id = "listing-id"
|
||||
mock_listing.slug = "test-agent"
|
||||
mock_listing.activeVersionId = SLV_ID
|
||||
mock_listing.hasApprovedVersion = False
|
||||
mock_listing.CreatorProfile = MagicMock(username="creator", avatarUrl="")
|
||||
mock_slv.StoreListing = mock_listing
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.store.db.prisma.models" ".StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch(
|
||||
"backend.api.features.store.db.prisma.models.StoreAgent.prisma",
|
||||
) as mock_store_agent_prisma,
|
||||
):
|
||||
mock_slv_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await get_store_agent_details_as_admin(SLV_ID)
|
||||
|
||||
# Verify it queried StoreListingVersion (not the APPROVED-only StoreAgent)
|
||||
mock_slv_prisma.return_value.find_unique.assert_awaited_once()
|
||||
await_args = mock_slv_prisma.return_value.find_unique.await_args
|
||||
assert await_args is not None
|
||||
assert await_args.kwargs["where"] == {"id": SLV_ID}
|
||||
|
||||
# Verify the APPROVED-only StoreAgent view was NOT touched
|
||||
mock_store_agent_prisma.assert_not_called()
|
||||
|
||||
# Verify the result has the right data
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.agent_image == ["https://example.com/img.png"]
|
||||
assert result.has_approved_version is False
|
||||
assert result.runs == 0
|
||||
assert result.rating == 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_graph_admin_uses_get_graph_as_admin() -> None:
|
||||
"""resolve_graph_for_library(admin=True) must call get_graph_as_admin,
|
||||
not get_graph. This is THE test that prevents the add-to-library bypass
|
||||
from being accidentally reverted."""
|
||||
from backend.api.features.library._add_to_library import resolve_graph_for_library
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models"
|
||||
".StoreListingVersion.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
".get_graph_as_admin",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_graph_model,
|
||||
) as mock_admin,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db.get_graph",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_regular,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await resolve_graph_for_library(SLV_ID, ADMIN_USER_ID, admin=True)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_admin.assert_awaited_once_with(
|
||||
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id=ADMIN_USER_ID
|
||||
)
|
||||
mock_regular.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_graph_regular_uses_get_graph() -> None:
|
||||
"""resolve_graph_for_library(admin=False) must call get_graph,
|
||||
not get_graph_as_admin. Ensures the non-admin path is preserved."""
|
||||
from backend.api.features.library._add_to_library import resolve_graph_for_library
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models"
|
||||
".StoreListingVersion.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
".get_graph_as_admin",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_admin,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db.get_graph",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_graph_model,
|
||||
) as mock_regular,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await resolve_graph_for_library(SLV_ID, "regular-user-id", admin=False)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_regular.assert_awaited_once_with(
|
||||
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id="regular-user-id"
|
||||
)
|
||||
mock_admin.assert_not_awaited()
|
||||
|
||||
|
||||
# ---- Library membership grants graph access (product decision) ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_member_can_view_pending_agent_in_builder() -> None:
|
||||
"""After adding a pending agent to their library, the user should be
|
||||
able to load the graph in the builder via get_graph()."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
mock_library_agent = MagicMock()
|
||||
mock_library_agent.AgentGraph = mock_graph
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_lib_prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_library_agent
|
||||
)
|
||||
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model, "Library membership should grant graph access"
|
||||
@@ -1,137 +0,0 @@
|
||||
import logging
|
||||
import math
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||
|
||||
from backend.data.invited_user import (
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
list_invited_users,
|
||||
retry_invited_user_tally,
|
||||
revoke_invited_user,
|
||||
)
|
||||
from backend.data.tally import mask_email
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["users", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/invited-users",
|
||||
response_model=InvitedUsersResponse,
|
||||
summary="List Invited Users",
|
||||
)
|
||||
async def get_invited_users(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
) -> InvitedUsersResponse:
|
||||
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||
invited_users, total = await list_invited_users(page=page, page_size=page_size)
|
||||
return InvitedUsersResponse(
|
||||
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=max(1, math.ceil(total / page_size)),
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Create Invited User",
|
||||
)
|
||||
async def create_invited_user_route(
|
||||
request: CreateInvitedUserRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s creating invited user for %s",
|
||||
admin_user_id,
|
||||
mask_email(request.email),
|
||||
)
|
||||
invited_user = await create_invited_user(request.email, request.name)
|
||||
logger.info(
|
||||
"Admin user %s created invited user %s",
|
||||
admin_user_id,
|
||||
invited_user.id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/bulk",
|
||||
response_model=BulkInvitedUsersResponse,
|
||||
summary="Bulk Create Invited Users",
|
||||
operation_id="postV2BulkCreateInvitedUsers",
|
||||
)
|
||||
async def bulk_create_invited_users_route(
|
||||
file: UploadFile = File(...),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> BulkInvitedUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s bulk invited users from %s",
|
||||
admin_user_id,
|
||||
file.filename or "<unnamed>",
|
||||
)
|
||||
content = await file.read()
|
||||
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||
return BulkInvitedUsersResponse.from_result(result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/revoke",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Revoke Invited User",
|
||||
)
|
||||
async def revoke_invited_user_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
|
||||
)
|
||||
invited_user = await revoke_invited_user(invited_user_id)
|
||||
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/retry-tally",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Retry Invited User Tally",
|
||||
)
|
||||
async def retry_invited_user_tally_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s retrying Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||
logger.info(
|
||||
"Admin user %s retried Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
@@ -1,168 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(user_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _sample_invited_user() -> InvitedUserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return InvitedUserRecord(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||
auth_user_id=None,
|
||||
name="Invited User",
|
||||
tally_understanding=None,
|
||||
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tally_computed_at=None,
|
||||
tally_error=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=1,
|
||||
skipped_count=1,
|
||||
error_count=0,
|
||||
results=[
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=1,
|
||||
email="invited@example.com",
|
||||
name=None,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=_sample_invited_user(),
|
||||
),
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=2,
|
||||
email="duplicate@example.com",
|
||||
name=None,
|
||||
status="SKIPPED",
|
||||
message="An invited user with this email already exists",
|
||||
invited_user=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||
AsyncMock(return_value=([_sample_invited_user()], 1)),
|
||||
)
|
||||
|
||||
response = client.get("/admin/invited-users")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["invited_users"]) == 1
|
||||
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||
assert data["invited_users"][0]["status"] == "INVITED"
|
||||
assert data["pagination"]["total_items"] == 1
|
||||
assert data["pagination"]["current_page"] == 1
|
||||
assert data["pagination"]["page_size"] == 50
|
||||
|
||||
|
||||
def test_create_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||
AsyncMock(return_value=_sample_invited_user()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users",
|
||||
json={"email": "invited@example.com", "name": "Invited User"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "invited@example.com"
|
||||
assert data["name"] == "Invited User"
|
||||
|
||||
|
||||
def test_bulk_create_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users/bulk",
|
||||
files={
|
||||
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 1
|
||||
assert data["skipped_count"] == 1
|
||||
assert data["results"][0]["status"] == "CREATED"
|
||||
assert data["results"][1]["status"] == "SKIPPED"
|
||||
|
||||
|
||||
def test_revoke_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _sample_invited_user().model_copy(
|
||||
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||
AsyncMock(return_value=revoked),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "REVOKED"
|
||||
|
||||
|
||||
def test_retry_invited_user_tally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
retried = _sample_invited_user().model_copy(
|
||||
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||
AsyncMock(return_value=retried),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
@@ -4,14 +4,12 @@ from difflib import SequenceMatcher
|
||||
from typing import Any, Sequence, get_args, get_origin
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from prisma.models import mv_suggested_blocks
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks._base import (
|
||||
AnyBlockSchema,
|
||||
@@ -24,6 +22,7 @@ from backend.blocks.llm import LlmModel
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.text import split_camelcase
|
||||
|
||||
from .model import (
|
||||
BlockCategoryResponse,
|
||||
@@ -271,7 +270,7 @@ async def _build_cached_search_results(
|
||||
|
||||
# Use hybrid search when query is present, otherwise list all blocks
|
||||
if (include_blocks or include_integrations) and normalized_query:
|
||||
block_results, block_total, integration_total = await _hybrid_search_blocks(
|
||||
block_results, block_total, integration_total = await _text_search_blocks(
|
||||
query=search_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
@@ -383,117 +382,75 @@ def _collect_block_results(
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
async def _hybrid_search_blocks(
|
||||
async def _text_search_blocks(
|
||||
*,
|
||||
query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
"""
|
||||
Search blocks using hybrid search with builder-specific filtering.
|
||||
Search blocks using in-memory text matching over the block registry.
|
||||
|
||||
Uses unified_hybrid_search for semantic + lexical search, then applies
|
||||
post-filtering for block/integration types and scoring adjustments.
|
||||
All blocks are already loaded in memory, so this is fast and reliable
|
||||
regardless of whether OpenAI embeddings are available.
|
||||
|
||||
Scoring:
|
||||
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
|
||||
- Base: text relevance via _score_primary_fields, plus BLOCK_SCORE_BOOST
|
||||
to prioritize blocks over marketplace agents in combined results
|
||||
- +30 for exact name match, +15 for prefix name match
|
||||
- +20 if the block has an LlmModel field and the query matches an LLM model name
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
include_blocks: Whether to include regular blocks
|
||||
include_integrations: Whether to include integration blocks
|
||||
|
||||
Returns:
|
||||
Tuple of (scored_items, block_count, integration_count)
|
||||
"""
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
|
||||
if not include_blocks and not include_integrations:
|
||||
return results, block_count, integration_count
|
||||
return results, 0, 0
|
||||
|
||||
normalized_query = query.strip().lower()
|
||||
|
||||
# Fetch more results to account for post-filtering
|
||||
search_results, _ = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=150,
|
||||
min_score=0.10,
|
||||
all_results, _, _ = _collect_block_results(
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
|
||||
# Load all blocks for getting BlockInfo
|
||||
all_blocks = load_all_blocks()
|
||||
|
||||
for result in search_results:
|
||||
block_id = result["content_id"]
|
||||
for item in all_results:
|
||||
block_info = item.item
|
||||
assert isinstance(block_info, BlockInfo)
|
||||
name = split_camelcase(block_info.name).lower()
|
||||
|
||||
# Skip excluded blocks
|
||||
if block_id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
# Build rich description including input field descriptions,
|
||||
# matching the searchable text that the embedding pipeline uses
|
||||
desc_parts = [block_info.description or ""]
|
||||
block_cls = all_blocks.get(block_info.id)
|
||||
if block_cls is not None:
|
||||
block: AnyBlockSchema = block_cls()
|
||||
desc_parts += [
|
||||
f"{f}: {info.description}"
|
||||
for f, info in block.input_schema.model_fields.items()
|
||||
if info.description
|
||||
]
|
||||
description = " ".join(desc_parts).lower()
|
||||
|
||||
metadata = result.get("metadata", {})
|
||||
hybrid_score = result.get("relevance", 0.0)
|
||||
|
||||
# Get the actual block class
|
||||
if block_id not in all_blocks:
|
||||
continue
|
||||
|
||||
block_cls = all_blocks[block_id]
|
||||
block: AnyBlockSchema = block_cls()
|
||||
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
# Check block/integration filter using metadata
|
||||
is_integration = metadata.get("is_integration", False)
|
||||
|
||||
if is_integration and not include_integrations:
|
||||
continue
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
# Get block info
|
||||
block_info = block.get_info()
|
||||
|
||||
# Calculate final score: scale hybrid score and add builder-specific bonuses
|
||||
# Hybrid scores are 0-1, builder scores were 0-200+
|
||||
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
|
||||
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
|
||||
score = _score_primary_fields(name, description, normalized_query)
|
||||
|
||||
# Add LLM model match bonus
|
||||
has_llm_field = metadata.get("has_llm_model_field", False)
|
||||
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
|
||||
final_score += 20
|
||||
if block_cls is not None and _matches_llm_model(
|
||||
block_cls().input_schema, normalized_query
|
||||
):
|
||||
score += 20
|
||||
|
||||
# Add exact/prefix match bonus for deterministic tie-breaking
|
||||
name = block_info.name.lower()
|
||||
if name == normalized_query:
|
||||
final_score += 30
|
||||
elif name.startswith(normalized_query):
|
||||
final_score += 15
|
||||
|
||||
# Track counts
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
else:
|
||||
block_count += 1
|
||||
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=final_score,
|
||||
sort_key=name,
|
||||
if score >= MIN_SCORE_FOR_FILTERED_RESULTS:
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=item.filter_type,
|
||||
score=score + BLOCK_SCORE_BOOST,
|
||||
sort_key=name,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
block_count = sum(1 for r in results if r.filter_type == "blocks")
|
||||
integration_count = sum(1 for r in results if r.filter_type == "integrations")
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
|
||||
@@ -60,7 +60,6 @@ from backend.copilot.tools.models import (
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -895,36 +894,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Suggested Prompts ==========
|
||||
|
||||
|
||||
class SuggestedPromptsResponse(BaseModel):
|
||||
"""Response model for user-specific suggested prompts."""
|
||||
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggested-prompts",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def get_suggested_prompts(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SuggestedPromptsResponse:
|
||||
"""
|
||||
Get LLM-generated suggested prompts for the authenticated user.
|
||||
|
||||
Returns personalized quick-action prompts based on the user's
|
||||
business understanding. Returns an empty list if no custom prompts
|
||||
are available.
|
||||
"""
|
||||
understanding = await get_business_understanding(user_id)
|
||||
if understanding is None:
|
||||
return SuggestedPromptsResponse(prompts=[])
|
||||
|
||||
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, rate limiting, and suggested prompts."""
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -400,62 +400,3 @@ def test_usage_rejects_unauthenticated_request() -> None:
|
||||
response = unauthenticated_client.get("/usage")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding and prompts gets them back."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but no prompts gets empty list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = []
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""Override session-scoped fixtures so unit tests run without the server."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server():
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def graph_cleanup():
|
||||
yield
|
||||
@@ -34,6 +34,7 @@ from backend.data.model import (
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
@@ -138,6 +139,18 @@ class CredentialsMetaResponse(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
|
||||
return CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||
async def callback(
|
||||
provider: Annotated[
|
||||
@@ -204,15 +217,7 @@ async def callback(
|
||||
f"and provider {provider.value}"
|
||||
)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(CredentialsMetaResponse.get_host(credentials)),
|
||||
)
|
||||
return to_meta_response(credentials)
|
||||
|
||||
|
||||
@router.get("/credentials", summary="List Credentials")
|
||||
@@ -222,16 +227,7 @@ async def list_credentials(
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -245,16 +241,7 @@ async def list_credentials_by_provider(
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -267,18 +254,21 @@ async def get_credential(
|
||||
],
|
||||
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> Credentials:
|
||||
) -> CredentialsMetaResponse:
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
credential = await creds_manager.get(user_id, cred_id)
|
||||
if not credential:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if credential.provider != provider:
|
||||
if not provider_matches(credential.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
return credential
|
||||
return to_meta_response(credential)
|
||||
|
||||
|
||||
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
|
||||
@@ -288,16 +278,22 @@ async def create_credentials(
|
||||
ProviderName, Path(title="The provider to create credentials for")
|
||||
],
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
) -> CredentialsMetaResponse:
|
||||
if is_sdk_default(credentials.id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Cannot create credentials with a reserved ID",
|
||||
)
|
||||
credentials.provider = provider
|
||||
try:
|
||||
await creds_manager.create(user_id, credentials)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Failed to store credentials")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
detail="Failed to store credentials",
|
||||
)
|
||||
return credentials
|
||||
return to_meta_response(credentials)
|
||||
|
||||
|
||||
class CredentialsDeletionResponse(BaseModel):
|
||||
@@ -332,15 +328,19 @@ async def delete_credentials(
|
||||
bool, Query(title="Whether to proceed if any linked webhooks are still in use")
|
||||
] = False,
|
||||
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if creds.provider != provider:
|
||||
if not provider_matches(creds.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
detail="Credentials not found",
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.integrations.router import router
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "test-user-id"
|
||||
|
||||
|
||||
def _make_api_key_cred(cred_id: str = "cred-123", provider: str = "openai"):
|
||||
return APIKeyCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My API Key",
|
||||
api_key=SecretStr("sk-secret-key-value"),
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_cred(cred_id: str = "cred-456", provider: str = "github"):
|
||||
return OAuth2Credentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My OAuth",
|
||||
access_token=SecretStr("ghp_secret_token"),
|
||||
refresh_token=SecretStr("ghp_refresh_secret"),
|
||||
scopes=["repo", "user"],
|
||||
username="testuser",
|
||||
)
|
||||
|
||||
|
||||
def _make_user_password_cred(cred_id: str = "cred-789", provider: str = "openai"):
|
||||
return UserPasswordCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My Login",
|
||||
username=SecretStr("admin"),
|
||||
password=SecretStr("s3cret-pass"),
|
||||
)
|
||||
|
||||
|
||||
def _make_host_scoped_cred(cred_id: str = "cred-host", provider: str = "openai"):
|
||||
return HostScopedCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="Host Cred",
|
||||
host="https://api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer top-secret")},
|
||||
)
|
||||
|
||||
|
||||
def _make_sdk_default_cred(provider: str = "openai"):
|
||||
return APIKeyCredentials(
|
||||
id=f"{provider}-default",
|
||||
provider=provider,
|
||||
title=f"{provider} (default)",
|
||||
api_key=SecretStr("sk-platform-secret-key"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestGetCredentialReturnsMetaOnly:
|
||||
"""GET /{provider}/credentials/{cred_id} must not return secrets."""
|
||||
|
||||
def test_api_key_credential_no_secret(self):
|
||||
cred = _make_api_key_cred()
|
||||
with (
|
||||
patch.object(router, "dependencies", []),
|
||||
patch("backend.api.features.integrations.router.creds_manager") as mock_mgr,
|
||||
):
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-123")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-123"
|
||||
assert data["provider"] == "openai"
|
||||
assert data["type"] == "api_key"
|
||||
assert "api_key" not in data
|
||||
assert "sk-secret-key-value" not in str(data)
|
||||
|
||||
def test_oauth2_credential_no_secret(self):
|
||||
cred = _make_oauth2_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-456")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-456"
|
||||
assert data["scopes"] == ["repo", "user"]
|
||||
assert data["username"] == "testuser"
|
||||
assert "access_token" not in data
|
||||
assert "refresh_token" not in data
|
||||
assert "ghp_" not in str(data)
|
||||
|
||||
def test_user_password_credential_no_secret(self):
|
||||
cred = _make_user_password_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-789")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-789"
|
||||
assert "password" not in data
|
||||
assert "username" not in data or data["username"] is None
|
||||
assert "s3cret-pass" not in str(data)
|
||||
assert "admin" not in str(data)
|
||||
|
||||
def test_host_scoped_credential_no_secret(self):
|
||||
cred = _make_host_scoped_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-host")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-host"
|
||||
assert data["host"] == "https://api.example.com"
|
||||
assert "headers" not in data
|
||||
assert "top-secret" not in str(data)
|
||||
|
||||
def test_get_credential_wrong_provider_returns_404(self):
|
||||
"""Provider mismatch should return generic 404, not leak credential existence."""
|
||||
cred = _make_api_key_cred(provider="openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-123")
|
||||
|
||||
assert resp.status_code == 404
|
||||
assert resp.json()["detail"] == "Credentials not found"
|
||||
|
||||
def test_list_credentials_no_secrets(self):
|
||||
"""List endpoint must not leak secrets in any credential."""
|
||||
creds = [_make_api_key_cred(), _make_oauth2_cred()]
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=creds)
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
raw = str(resp.json())
|
||||
assert "sk-secret-key-value" not in raw
|
||||
assert "ghp_secret_token" not in raw
|
||||
assert "ghp_refresh_secret" not in raw
|
||||
|
||||
|
||||
class TestSdkDefaultCredentialsNotAccessible:
|
||||
"""SDK default credentials (ID ending in '-default') must be hidden."""
|
||||
|
||||
def test_get_sdk_default_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock()
|
||||
resp = client.get("/openai/credentials/openai-default")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.get.assert_not_called()
|
||||
|
||||
def test_list_credentials_excludes_sdk_defaults(self):
|
||||
user_cred = _make_api_key_cred()
|
||||
sdk_cred = _make_sdk_default_cred("openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=[user_cred, sdk_cred])
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
ids = [c["id"] for c in data]
|
||||
assert "cred-123" in ids
|
||||
assert "openai-default" not in ids
|
||||
|
||||
def test_list_by_provider_excludes_sdk_defaults(self):
|
||||
user_cred = _make_api_key_cred()
|
||||
sdk_cred = _make_sdk_default_cred("openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[user_cred, sdk_cred]
|
||||
)
|
||||
resp = client.get("/openai/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
ids = [c["id"] for c in data]
|
||||
assert "cred-123" in ids
|
||||
assert "openai-default" not in ids
|
||||
|
||||
def test_delete_sdk_default_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_id = AsyncMock()
|
||||
resp = client.request("DELETE", "/openai/credentials/openai-default")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.store.get_creds_by_id.assert_not_called()
|
||||
|
||||
|
||||
class TestCreateCredentialNoSecretInResponse:
|
||||
"""POST /{provider}/credentials must not return secrets."""
|
||||
|
||||
def test_create_api_key_no_secret_in_response(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.create = AsyncMock()
|
||||
resp = client.post(
|
||||
"/openai/credentials",
|
||||
json={
|
||||
"id": "new-cred",
|
||||
"provider": "openai",
|
||||
"type": "api_key",
|
||||
"title": "New Key",
|
||||
"api_key": "sk-newsecret",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["id"] == "new-cred"
|
||||
assert "api_key" not in data
|
||||
assert "sk-newsecret" not in str(data)
|
||||
|
||||
def test_create_with_sdk_default_id_rejected(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.create = AsyncMock()
|
||||
resp = client.post(
|
||||
"/openai/credentials",
|
||||
json={
|
||||
"id": "openai-default",
|
||||
"provider": "openai",
|
||||
"type": "api_key",
|
||||
"title": "Sneaky",
|
||||
"api_key": "sk-evil",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 403
|
||||
mock_mgr.create.assert_not_called()
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Shared logic for adding store agents to a user's library.
|
||||
|
||||
Both `add_store_agent_to_library` and `add_store_agent_to_library_as_admin`
|
||||
delegate to these helpers so the duplication-prone create/restore/dedup
|
||||
logic lives in exactly one place.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.data.graph as graph_db
|
||||
from backend.data.graph import GraphModel, GraphSettings
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
from .db import get_library_agent_by_graph_id, update_library_agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def resolve_graph_for_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
admin: bool,
|
||||
) -> GraphModel:
|
||||
"""Look up a StoreListingVersion and resolve its graph.
|
||||
|
||||
When ``admin=True``, uses ``get_graph_as_admin`` to bypass the marketplace
|
||||
APPROVED-only check. Otherwise uses the regular ``get_graph``.
|
||||
"""
|
||||
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
)
|
||||
if not slv or not slv.AgentGraph:
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
ag = slv.AgentGraph
|
||||
if admin:
|
||||
graph_model = await graph_db.get_graph_as_admin(
|
||||
graph_id=ag.id, version=ag.version, user_id=user_id
|
||||
)
|
||||
else:
|
||||
graph_model = await graph_db.get_graph(
|
||||
graph_id=ag.id, version=ag.version, user_id=user_id
|
||||
)
|
||||
|
||||
if not graph_model:
|
||||
raise NotFoundError(f"Graph #{ag.id} v{ag.version} not found or accessible")
|
||||
return graph_model
|
||||
|
||||
|
||||
async def add_graph_to_library(
|
||||
store_listing_version_id: str,
|
||||
graph_model: GraphModel,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Check existing / restore soft-deleted / create new LibraryAgent."""
|
||||
if existing := await get_library_agent_by_graph_id(
|
||||
user_id, graph_model.id, graph_model.version
|
||||
):
|
||||
return existing
|
||||
|
||||
deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_model.id,
|
||||
"agentGraphVersion": graph_model.version,
|
||||
}
|
||||
},
|
||||
)
|
||||
if deleted_agent and (deleted_agent.isDeleted or deleted_agent.isArchived):
|
||||
return await update_library_agent(
|
||||
deleted_agent.id,
|
||||
user_id,
|
||||
is_deleted=False,
|
||||
is_archived=False,
|
||||
)
|
||||
|
||||
try:
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_model.id,
|
||||
"version": graph_model.version,
|
||||
}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": SafeJson(
|
||||
GraphSettings.from_graph(graph_model).model_dump()
|
||||
),
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
except prisma.errors.UniqueViolationError:
|
||||
# Race condition: concurrent request created the row between our
|
||||
# check and create. Re-read instead of crashing.
|
||||
existing = await get_library_agent_by_graph_id(
|
||||
user_id, graph_model.id, graph_model.version
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
raise # Shouldn't happen, but don't swallow unexpected errors
|
||||
|
||||
logger.debug(
|
||||
f"Added graph #{graph_model.id} v{graph_model.version} "
|
||||
f"for store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
@@ -0,0 +1,71 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ._add_to_library import add_graph_to_library
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_restores_archived_agent() -> None:
|
||||
graph_model = MagicMock(id="graph-id", version=2)
|
||||
archived_agent = MagicMock(id="library-agent-id", isDeleted=False, isArchived=True)
|
||||
restored_agent = MagicMock(name="LibraryAgentModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.get_library_agent_by_graph_id",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.update_library_agent",
|
||||
new=AsyncMock(return_value=restored_agent),
|
||||
) as mock_update,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=archived_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is restored_agent
|
||||
mock_update.assert_awaited_once_with(
|
||||
"library-agent-id",
|
||||
"user-id",
|
||||
is_deleted=False,
|
||||
is_archived=False,
|
||||
)
|
||||
mock_prisma.return_value.create.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_restores_deleted_agent() -> None:
|
||||
graph_model = MagicMock(id="graph-id", version=2)
|
||||
deleted_agent = MagicMock(id="library-agent-id", isDeleted=True, isArchived=False)
|
||||
restored_agent = MagicMock(name="LibraryAgentModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.get_library_agent_by_graph_id",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.update_library_agent",
|
||||
new=AsyncMock(return_value=restored_agent),
|
||||
) as mock_update,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=deleted_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is restored_agent
|
||||
mock_update.assert_awaited_once_with(
|
||||
"library-agent-id",
|
||||
"user-id",
|
||||
is_deleted=False,
|
||||
is_archived=False,
|
||||
)
|
||||
mock_prisma.return_value.create.assert_not_called()
|
||||
@@ -336,12 +336,15 @@ async def get_library_agent_by_graph_id(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_version: Optional[int] = None,
|
||||
include_archived: bool = False,
|
||||
) -> library_model.LibraryAgent | None:
|
||||
filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"agentGraphId": graph_id,
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
if not include_archived:
|
||||
filter["isArchived"] = False
|
||||
if graph_version is not None:
|
||||
filter["agentGraphVersion"] = graph_version
|
||||
|
||||
@@ -582,7 +585,9 @@ async def update_graph_in_library(
|
||||
|
||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||
|
||||
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
||||
library_agent = await get_library_agent_by_graph_id(
|
||||
user_id, created_graph.id, include_archived=True
|
||||
)
|
||||
if not library_agent:
|
||||
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
||||
|
||||
@@ -818,92 +823,38 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
|
||||
async def add_store_agent_to_library(
|
||||
store_listing_version_id: str, user_id: str
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Adds a marketplace agent to the user’s library.
|
||||
|
||||
See also: `add_store_agent_to_library_as_admin()` which uses
|
||||
`get_graph_as_admin` to bypass marketplace status checks for admin review.
|
||||
"""
|
||||
Adds an agent from a store listing version to the user's library if they don't already have it.
|
||||
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
|
||||
|
||||
Args:
|
||||
store_listing_version_id: The ID of the store listing version containing the agent.
|
||||
user_id: The user’s library to which the agent is being added.
|
||||
|
||||
Returns:
|
||||
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the store listing or associated agent is not found.
|
||||
DatabaseError: If there's an issue creating the LibraryAgent record.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Adding agent from store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
)
|
||||
graph_model = await resolve_graph_for_library(
|
||||
store_listing_version_id, user_id, admin=False
|
||||
)
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
|
||||
|
||||
graph = store_listing_version.AgentGraph
|
||||
|
||||
# Convert to GraphModel to check for HITL blocks
|
||||
graph_model = await graph_db.get_graph(
|
||||
graph_id=graph.id,
|
||||
version=graph.version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=False,
|
||||
async def add_store_agent_to_library_as_admin(
|
||||
store_listing_version_id: str, user_id: str
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Admin variant that uses `get_graph_as_admin` to bypass marketplace
|
||||
APPROVED-only checks, allowing admins to add pending agents for review."""
|
||||
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
|
||||
|
||||
logger.warning(
|
||||
f"ADMIN adding agent from store listing version "
|
||||
f"#{store_listing_version_id} to library for user #{user_id}"
|
||||
)
|
||||
if not graph_model:
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph.id} v{graph.version} not found or accessible"
|
||||
)
|
||||
|
||||
# Check if user already has this agent (non-deleted)
|
||||
if existing := await get_library_agent_by_graph_id(
|
||||
user_id, graph.id, graph.version
|
||||
):
|
||||
return existing
|
||||
|
||||
# Check for soft-deleted version and restore it
|
||||
deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
}
|
||||
},
|
||||
graph_model = await resolve_graph_for_library(
|
||||
store_listing_version_id, user_id, admin=True
|
||||
)
|
||||
if deleted_agent and deleted_agent.isDeleted:
|
||||
return await update_library_agent(deleted_agent.id, user_id, is_deleted=False)
|
||||
|
||||
# Create LibraryAgent entry
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": SafeJson(GraphSettings.from_graph(graph_model).model_dump()),
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
logger.debug(
|
||||
f"Added graph #{graph.id} v{graph.version}"
|
||||
f"for store listing version #{store_listing_version.id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -150,8 +150,13 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
# Mock graph_db.get_graph function that's called to check for HITL blocks
|
||||
mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
|
||||
# (lives in _add_to_library.py after refactor, not db.py)
|
||||
mock_graph_db = mocker.patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
)
|
||||
mock_graph_model = mocker.Mock()
|
||||
mock_graph_model.id = "agent1"
|
||||
mock_graph_model.version = 1
|
||||
mock_graph_model.nodes = (
|
||||
[]
|
||||
) # Empty list so _has_human_in_the_loop_blocks returns False
|
||||
@@ -224,3 +229,94 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agent_by_graph_id_excludes_archived(mocker):
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
|
||||
result = await db.get_library_agent_by_graph_id("test-user", "agent1", 7)
|
||||
|
||||
assert result is None
|
||||
mock_library_agent.return_value.find_first.assert_called_once()
|
||||
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
|
||||
assert where == {
|
||||
"agentGraphId": "agent1",
|
||||
"userId": "test-user",
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"agentGraphVersion": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agent_by_graph_id_can_include_archived(mocker):
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
|
||||
result = await db.get_library_agent_by_graph_id(
|
||||
"test-user",
|
||||
"agent1",
|
||||
7,
|
||||
include_archived=True,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
mock_library_agent.return_value.find_first.assert_called_once()
|
||||
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
|
||||
assert where == {
|
||||
"agentGraphId": "agent1",
|
||||
"userId": "test-user",
|
||||
"isDeleted": False,
|
||||
"agentGraphVersion": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_graph_in_library_allows_archived_library_agent(mocker):
|
||||
graph = mocker.Mock(id="graph-id")
|
||||
existing_version = mocker.Mock(version=1, is_active=True)
|
||||
graph_model = mocker.Mock()
|
||||
created_graph = mocker.Mock(id="graph-id", version=2, is_active=False)
|
||||
current_library_agent = mocker.Mock()
|
||||
updated_library_agent = mocker.Mock()
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.get_graph_all_versions",
|
||||
new=mocker.AsyncMock(return_value=[existing_version]),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.make_graph_model",
|
||||
return_value=graph_model,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.create_graph",
|
||||
new=mocker.AsyncMock(return_value=created_graph),
|
||||
)
|
||||
mock_get_library_agent = mocker.patch(
|
||||
"backend.api.features.library.db.get_library_agent_by_graph_id",
|
||||
new=mocker.AsyncMock(return_value=current_library_agent),
|
||||
)
|
||||
mock_update_library_agent = mocker.patch(
|
||||
"backend.api.features.library.db.update_library_agent_version_and_settings",
|
||||
new=mocker.AsyncMock(return_value=updated_library_agent),
|
||||
)
|
||||
|
||||
result_graph, result_library_agent = await db.update_graph_in_library(
|
||||
graph,
|
||||
"test-user",
|
||||
)
|
||||
|
||||
assert result_graph is created_graph
|
||||
assert result_library_agent is updated_library_agent
|
||||
assert graph.version == 2
|
||||
graph_model.reassign_ids.assert_called_once_with(
|
||||
user_id="test-user", reassign_graph_id=False
|
||||
)
|
||||
mock_get_library_agent.assert_awaited_once_with(
|
||||
"test-user",
|
||||
"graph-id",
|
||||
include_archived=True,
|
||||
)
|
||||
mock_update_library_agent.assert_awaited_once_with("test-user", created_graph)
|
||||
|
||||
@@ -9,7 +9,7 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data.db import query_raw_with_schema, transaction
|
||||
from backend.data.graph import (
|
||||
GraphModel,
|
||||
GraphModelWithoutNodes,
|
||||
@@ -104,7 +104,8 @@ async def get_store_agents(
|
||||
# search_used_hybrid remains False, will use fallback path below
|
||||
|
||||
# Convert hybrid search results (dict format) if hybrid succeeded
|
||||
if search_used_hybrid:
|
||||
# Fall through to direct DB search if hybrid returned nothing
|
||||
if search_used_hybrid and agents:
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
@@ -130,52 +131,20 @@ async def get_store_agents(
|
||||
)
|
||||
continue
|
||||
|
||||
if not search_used_hybrid:
|
||||
# Fallback path - use basic search or no search
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
# Add basic text search if search_query provided but hybrid failed
|
||||
if search_query:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == StoreAgentsSortOptions.RATING:
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.RUNS:
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.NAME:
|
||||
order_by.append({"agent_name": "asc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.UPDATED_AT:
|
||||
order_by.append({"updated_at": "desc"})
|
||||
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
if not search_used_hybrid or not agents:
|
||||
# Fallback path: direct DB query with optional tsvector search.
|
||||
# This mirrors the original pre-hybrid-search implementation.
|
||||
store_agents, total = await _fallback_store_agent_search(
|
||||
search_query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in db_agents:
|
||||
try:
|
||||
store_agents.append(store_model.StoreAgent.from_db(agent))
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing StoreAgent from db: {e}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(store_agents)} agents")
|
||||
return store_model.StoreAgentsResponse(
|
||||
agents=store_agents,
|
||||
@@ -195,6 +164,126 @@ async def get_store_agents(
|
||||
# await log_search_term(search_query=search_term)
|
||||
|
||||
|
||||
async def _fallback_store_agent_search(
|
||||
*,
|
||||
search_query: str | None,
|
||||
featured: bool,
|
||||
creators: list[str] | None,
|
||||
category: str | None,
|
||||
sorted_by: StoreAgentsSortOptions | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> tuple[list[store_model.StoreAgent], int]:
|
||||
"""Direct DB search fallback when hybrid search is unavailable or empty.
|
||||
|
||||
Uses ad-hoc to_tsvector/plainto_tsquery with ts_rank_cd for text search,
|
||||
matching the quality of the original pre-hybrid-search implementation.
|
||||
Falls back to simple listing when no search query is provided.
|
||||
"""
|
||||
if not search_query:
|
||||
# No search query — use Prisma for simple filtered listing
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
order_by = []
|
||||
if sorted_by == StoreAgentsSortOptions.RATING:
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.RUNS:
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.NAME:
|
||||
order_by.append({"agent_name": "asc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.UPDATED_AT:
|
||||
order_by.append({"updated_at": "desc"})
|
||||
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
return [store_model.StoreAgent.from_db(a) for a in db_agents], total
|
||||
|
||||
# Text search using ad-hoc tsvector on StoreAgent view fields
|
||||
params: list[Any] = [search_query]
|
||||
filters = ["sa.is_available = true"]
|
||||
param_idx = 2
|
||||
|
||||
if featured:
|
||||
filters.append("sa.featured = true")
|
||||
if creators:
|
||||
params.append(creators)
|
||||
filters.append(f"sa.creator_username = ANY(${param_idx})")
|
||||
param_idx += 1
|
||||
if category:
|
||||
params.append(category)
|
||||
filters.append(f"${param_idx} = ANY(sa.categories)")
|
||||
param_idx += 1
|
||||
|
||||
where_sql = " AND ".join(filters)
|
||||
|
||||
params.extend([page_size, (page - 1) * page_size])
|
||||
limit_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
offset_param = f"${param_idx}"
|
||||
|
||||
sql = f"""
|
||||
WITH ranked AS (
|
||||
SELECT sa.*,
|
||||
ts_rank_cd(
|
||||
to_tsvector('english',
|
||||
COALESCE(sa.agent_name, '') || ' ' ||
|
||||
COALESCE(sa.sub_heading, '') || ' ' ||
|
||||
COALESCE(sa.description, '')
|
||||
),
|
||||
plainto_tsquery('english', $1)
|
||||
) AS rank,
|
||||
COUNT(*) OVER () AS total_count
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
WHERE {where_sql}
|
||||
AND to_tsvector('english',
|
||||
COALESCE(sa.agent_name, '') || ' ' ||
|
||||
COALESCE(sa.sub_heading, '') || ' ' ||
|
||||
COALESCE(sa.description, '')
|
||||
) @@ plainto_tsquery('english', $1)
|
||||
)
|
||||
SELECT * FROM ranked
|
||||
ORDER BY rank DESC
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(sql, *params)
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
store_agents = []
|
||||
for row in results:
|
||||
try:
|
||||
store_agents.append(
|
||||
store_model.StoreAgent(
|
||||
slug=row["slug"],
|
||||
agent_name=row["agent_name"],
|
||||
agent_image=row["agent_image"][0] if row["agent_image"] else "",
|
||||
creator=row["creator_username"] or "Needs Profile",
|
||||
creator_avatar=row["creator_avatar"] or "",
|
||||
sub_heading=row["sub_heading"],
|
||||
description=row["description"],
|
||||
runs=row["runs"],
|
||||
rating=row["rating"],
|
||||
agent_graph_id=row.get("graph_id", ""),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing StoreAgent from fallback search: {e}")
|
||||
continue
|
||||
|
||||
return store_agents, total
|
||||
|
||||
|
||||
async def log_search_term(search_query: str):
|
||||
"""Log a search term to the database"""
|
||||
|
||||
@@ -302,6 +391,11 @@ async def get_available_graph(
|
||||
async def get_store_agent_by_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get agent details from the StoreAgent view (APPROVED agents only).
|
||||
|
||||
See also: `get_store_agent_details_as_admin()` which bypasses the
|
||||
APPROVED-only StoreAgent view for admin preview of pending submissions.
|
||||
"""
|
||||
logger.debug(f"Getting store agent details for {store_listing_version_id}")
|
||||
|
||||
try:
|
||||
@@ -322,6 +416,57 @@ async def get_store_agent_by_version_id(
|
||||
raise DatabaseError("Failed to fetch agent details") from e
|
||||
|
||||
|
||||
async def get_store_agent_details_as_admin(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get agent details for admin preview, bypassing the APPROVED-only
|
||||
StoreAgent view. Queries StoreListingVersion directly so pending
|
||||
submissions are visible."""
|
||||
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id},
|
||||
include={
|
||||
"StoreListing": {"include": {"CreatorProfile": True}},
|
||||
},
|
||||
)
|
||||
if not slv or not slv.StoreListing:
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found"
|
||||
)
|
||||
|
||||
listing = slv.StoreListing
|
||||
# CreatorProfile is a required FK relation — should always exist.
|
||||
# If it's None, the DB is in a bad state.
|
||||
profile = listing.CreatorProfile
|
||||
if not profile:
|
||||
raise DatabaseError(
|
||||
f"StoreListing {listing.id} has no CreatorProfile — FK violated"
|
||||
)
|
||||
|
||||
return store_model.StoreAgentDetails(
|
||||
store_listing_version_id=slv.id,
|
||||
slug=listing.slug,
|
||||
agent_name=slv.name,
|
||||
agent_video=slv.videoUrl or "",
|
||||
agent_output_demo=slv.agentOutputDemoUrl or "",
|
||||
agent_image=slv.imageUrls,
|
||||
creator=profile.username,
|
||||
creator_avatar=profile.avatarUrl or "",
|
||||
sub_heading=slv.subHeading,
|
||||
description=slv.description,
|
||||
instructions=slv.instructions,
|
||||
categories=slv.categories,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
versions=[str(slv.version)],
|
||||
graph_id=slv.agentGraphId,
|
||||
graph_versions=[str(slv.agentGraphVersion)],
|
||||
last_updated=slv.updatedAt,
|
||||
recommended_schedule_cron=slv.recommendedScheduleCron,
|
||||
active_version_id=listing.activeVersionId or slv.id,
|
||||
has_approved_version=listing.hasApprovedVersion,
|
||||
)
|
||||
|
||||
|
||||
class StoreCreatorsSortOptions(Enum):
|
||||
# NOTE: values correspond 1:1 to columns of the Creator view
|
||||
AGENT_RATING = "agent_rating"
|
||||
@@ -1139,16 +1284,21 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (blocking - admin operation)
|
||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=submission.name,
|
||||
description=submission.description,
|
||||
sub_heading=submission.subHeading,
|
||||
categories=submission.categories,
|
||||
tx=tx,
|
||||
)
|
||||
# Generate embedding for approved listing (best-effort)
|
||||
try:
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=submission.name,
|
||||
description=submission.description,
|
||||
sub_heading=submission.subHeading,
|
||||
categories=submission.categories,
|
||||
tx=tx,
|
||||
)
|
||||
except Exception as emb_err:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for listing "
|
||||
f"{store_listing_version_id}: {emb_err}"
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": submission.storeListingId},
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import tempfile
|
||||
import urllib.parse
|
||||
|
||||
import autogpt_libs.auth
|
||||
@@ -259,21 +258,18 @@ async def get_graph_meta_by_store_listing_version_id(
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str,
|
||||
) -> fastapi.responses.FileResponse:
|
||||
) -> fastapi.responses.Response:
|
||||
"""Download agent graph file for a specific marketplace listing version"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
return fastapi.responses.Response(
|
||||
content=backend.util.json.dumps(graph_data),
|
||||
media_type="application/json",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{file_name}"',
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -55,7 +55,6 @@ from backend.data.credit import (
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -71,6 +70,7 @@ from backend.data.onboarding import (
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
@@ -136,10 +136,12 @@ _tally_background_tasks: set[asyncio.Task] = set()
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_activate_user(user_data)
|
||||
user = await get_or_create_user(user_data)
|
||||
|
||||
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
|
||||
# not produce a stored result before first activation.
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
@@ -163,8 +165,7 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
email: str = Body(...),
|
||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
||||
) -> dict[str, str]:
|
||||
await update_user_email(user_id, email)
|
||||
|
||||
@@ -178,16 +179,10 @@ async def update_user_email_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> TimezoneResponse:
|
||||
"""Get user timezone setting."""
|
||||
try:
|
||||
user = await get_user_by_id(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail="User not found. Please complete activation via /auth/user first.",
|
||||
)
|
||||
user = await get_or_create_user(user_data)
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
|
||||
@@ -198,8 +193,7 @@ async def get_user_timezone_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
request: UpdateTimezoneRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
) -> TimezoneResponse:
|
||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||
user = await update_user_timezone(user_id, str(request.timezone))
|
||||
@@ -598,6 +592,11 @@ async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
|
||||
async def configure_user_auto_top_up(
|
||||
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> str:
|
||||
"""Configure auto top-up settings and perform an immediate top-up if needed.
|
||||
|
||||
Raises HTTPException(422) if the request parameters are invalid or if
|
||||
the credit top-up fails.
|
||||
"""
|
||||
if request.threshold < 0:
|
||||
raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
|
||||
if request.amount < 500 and request.amount != 0:
|
||||
@@ -612,10 +611,20 @@ async def configure_user_auto_top_up(
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
|
||||
if current_balance < request.threshold:
|
||||
await user_credit_model.top_up_credits(user_id, request.amount)
|
||||
else:
|
||||
await user_credit_model.top_up_credits(user_id, 0)
|
||||
try:
|
||||
if current_balance < request.threshold:
|
||||
await user_credit_model.top_up_credits(user_id, request.amount)
|
||||
else:
|
||||
await user_credit_model.top_up_credits(user_id, 0)
|
||||
except ValueError as e:
|
||||
known_messages = (
|
||||
"must not be negative",
|
||||
"already exists for user",
|
||||
"No payment method found",
|
||||
)
|
||||
if any(msg in str(e) for msg in known_messages):
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
raise
|
||||
|
||||
await set_auto_top_up(
|
||||
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
|
||||
@@ -971,14 +980,16 @@ async def execute_graph(
|
||||
source: Annotated[GraphExecutionSource | None, Body(embed=True)] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
dry_run: Annotated[bool, Body(embed=True)] = False,
|
||||
) -> execution_db.GraphExecutionMeta:
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
if not dry_run:
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await execution_utils.add_graph_execution(
|
||||
@@ -988,6 +999,7 @@ async def execute_graph(
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
# Record successful graph execution
|
||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_or_activate_user",
|
||||
"backend.api.features.v1.get_or_create_user",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -188,6 +188,7 @@ async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file: UploadFile,
|
||||
session_id: str | None = Query(default=None),
|
||||
overwrite: bool = Query(default=False),
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
Upload a file to the user's workspace.
|
||||
@@ -248,7 +249,9 @@ async def upload_file(
|
||||
# Write file via WorkspaceManager
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
try:
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
workspace_file = await manager.write_file(
|
||||
content, filename, overwrite=overwrite
|
||||
)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from prisma.errors import PrismaError
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.admin.user_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
import backend.api.features.chat.routes as chat_routes
|
||||
@@ -211,13 +210,22 @@ instrument_fastapi(
|
||||
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: fastapi.Request, exc: Exception):
|
||||
if log_error:
|
||||
logger.exception(
|
||||
"%s %s failed. Investigate and resolve the underlying issue: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
if status_code >= 500:
|
||||
logger.exception(
|
||||
"%s %s failed. Investigate and resolve the underlying issue: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"%s %s failed with %d: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
status_code,
|
||||
exc,
|
||||
)
|
||||
|
||||
hint = (
|
||||
"Adjust the request and retry."
|
||||
@@ -267,12 +275,10 @@ async def validation_error_handler(
|
||||
|
||||
|
||||
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
||||
app.add_exception_handler(
|
||||
FolderAlreadyExistsError, handle_internal_http_error(409, False)
|
||||
)
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
||||
app.add_exception_handler(FolderAlreadyExistsError, handle_internal_http_error(409))
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403))
|
||||
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||
@@ -312,11 +318,6 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.user_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/users",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
@@ -527,8 +528,11 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
user_id: str,
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
from .features.integrations.router import create_credentials, get_credential
|
||||
):
|
||||
from backend.api.features.integrations.router import (
|
||||
create_credentials,
|
||||
get_credential,
|
||||
)
|
||||
|
||||
try:
|
||||
return await create_credentials(
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Shared configuration for all AgentMail blocks.
|
||||
"""
|
||||
|
||||
from agentmail import AsyncAgentMail
|
||||
|
||||
from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
|
||||
|
||||
agent_mail = (
|
||||
ProviderBuilder("agent_mail")
|
||||
.with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="agent_mail",
|
||||
title="Mock AgentMail API Key",
|
||||
api_key=SecretStr("mock-agentmail-api-key"),
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
def _client(credentials: APIKeyCredentials) -> AsyncAgentMail:
|
||||
"""Create an AsyncAgentMail client from credentials."""
|
||||
return AsyncAgentMail(api_key=credentials.api_key.get_secret_value())
|
||||
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
AgentMail Attachment blocks — download file attachments from messages and threads.
|
||||
|
||||
Attachments are files associated with messages (PDFs, CSVs, images, etc.).
|
||||
To send attachments, include them in the attachments parameter when using
|
||||
AgentMailSendMessageBlock or AgentMailReplyToMessageBlock.
|
||||
|
||||
To download, first get the attachment_id from a message's attachments array,
|
||||
then use these blocks to retrieve the file content as base64.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailGetMessageAttachmentBlock(Block):
|
||||
"""
|
||||
Download a file attachment from a specific email message.
|
||||
|
||||
Retrieves the raw file content and returns it as base64-encoded data.
|
||||
First get the attachment_id from a message object's attachments array,
|
||||
then use this block to download the file.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID containing the attachment"
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="Attachment ID to download (from the message's attachments array)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
content_base64: str = SchemaField(
|
||||
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="The attachment ID that was downloaded"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a283ffc4-8087-4c3d-9135-8f26b86742ec",
|
||||
description="Download a file attachment from an email message. Returns base64-encoded file content.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"attachment_id": "test-attach",
|
||||
},
|
||||
test_output=[
|
||||
("content_base64", "dGVzdA=="),
|
||||
("attachment_id", "test-attach"),
|
||||
],
|
||||
test_mock={
|
||||
"get_attachment": lambda *a, **kw: b"test",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_attachment(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
message_id: str,
|
||||
attachment_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.get_attachment(
|
||||
inbox_id=inbox_id,
|
||||
message_id=message_id,
|
||||
attachment_id=attachment_id,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_attachment(
|
||||
credentials=credentials,
|
||||
inbox_id=input_data.inbox_id,
|
||||
message_id=input_data.message_id,
|
||||
attachment_id=input_data.attachment_id,
|
||||
)
|
||||
if isinstance(data, bytes):
|
||||
encoded = base64.b64encode(data).decode()
|
||||
elif isinstance(data, str):
|
||||
encoded = base64.b64encode(data.encode("utf-8")).decode()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected attachment data type: {type(data).__name__}"
|
||||
)
|
||||
|
||||
yield "content_base64", encoded
|
||||
yield "attachment_id", input_data.attachment_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetThreadAttachmentBlock(Block):
|
||||
"""
|
||||
Download a file attachment from a conversation thread.
|
||||
|
||||
Same as GetMessageAttachment but looks up by thread ID instead of
|
||||
message ID. Useful when you know the thread but not the specific
|
||||
message containing the attachment.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID containing the attachment")
|
||||
attachment_id: str = SchemaField(
|
||||
description="Attachment ID to download (from a message's attachments array within the thread)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
content_base64: str = SchemaField(
|
||||
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="The attachment ID that was downloaded"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="06b6a4c4-9d71-4992-9e9c-cf3b352763b5",
|
||||
description="Download a file attachment from a conversation thread. Returns base64-encoded file content.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
"attachment_id": "test-attach",
|
||||
},
|
||||
test_output=[
|
||||
("content_base64", "dGVzdA=="),
|
||||
("attachment_id", "test-attach"),
|
||||
],
|
||||
test_mock={
|
||||
"get_attachment": lambda *a, **kw: b"test",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_attachment(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
thread_id: str,
|
||||
attachment_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.get_attachment(
|
||||
inbox_id=inbox_id,
|
||||
thread_id=thread_id,
|
||||
attachment_id=attachment_id,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_attachment(
|
||||
credentials=credentials,
|
||||
inbox_id=input_data.inbox_id,
|
||||
thread_id=input_data.thread_id,
|
||||
attachment_id=input_data.attachment_id,
|
||||
)
|
||||
if isinstance(data, bytes):
|
||||
encoded = base64.b64encode(data).decode()
|
||||
elif isinstance(data, str):
|
||||
encoded = base64.b64encode(data.encode("utf-8")).decode()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected attachment data type: {type(data).__name__}"
|
||||
)
|
||||
|
||||
yield "content_base64", encoded
|
||||
yield "attachment_id", input_data.attachment_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
678
autogpt_platform/backend/backend/blocks/agent_mail/drafts.py
Normal file
678
autogpt_platform/backend/backend/blocks/agent_mail/drafts.py
Normal file
@@ -0,0 +1,678 @@
|
||||
"""
|
||||
AgentMail Draft blocks — create, get, list, update, send, and delete drafts.
|
||||
|
||||
A Draft is an unsent message that can be reviewed, edited, and sent later.
|
||||
Drafts enable human-in-the-loop review, scheduled sending (via send_at),
|
||||
and complex multi-step email composition workflows.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreateDraftBlock(Block):
|
||||
"""
|
||||
Create a draft email in an AgentMail inbox for review or scheduled sending.
|
||||
|
||||
Drafts let agents prepare emails without sending immediately. Use send_at
|
||||
to schedule automatic sending at a future time (ISO 8601 format).
|
||||
Scheduled drafts are auto-labeled 'scheduled' and can be cancelled by
|
||||
deleting the draft.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to create the draft in"
|
||||
)
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses (e.g. ['user@example.com'])"
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject line", default="")
|
||||
text: str = SchemaField(description="Plain text body of the draft", default="")
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the draft", default="", advanced=True
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
in_reply_to: str = SchemaField(
|
||||
description="Message ID this draft replies to, for threading follow-up drafts",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
send_at: str = SchemaField(
|
||||
description="Schedule automatic sending at this ISO 8601 datetime (e.g. '2025-01-15T09:00:00Z'). Leave empty for manual send.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(
|
||||
description="Unique identifier of the created draft"
|
||||
)
|
||||
send_status: str = SchemaField(
|
||||
description="'scheduled' if send_at was set, empty otherwise. Values: scheduled, sending, failed.",
|
||||
default="",
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete draft object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="25ac9086-69fd-48b8-b910-9dbe04b8f3bd",
|
||||
description="Create a draft email for review or scheduled sending. Use send_at for automatic future delivery.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"to": ["user@example.com"],
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "mock-draft-id"),
|
||||
("send_status", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "mock-draft-id",
|
||||
"send_status": "",
|
||||
"model_dump": lambda self: {"draft_id": "mock-draft-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_draft(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.create(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"to": input_data.to}
|
||||
if input_data.subject:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.in_reply_to:
|
||||
params["in_reply_to"] = input_data.in_reply_to
|
||||
if input_data.send_at:
|
||||
params["send_at"] = input_data.send_at
|
||||
|
||||
draft = await self.create_draft(credentials, input_data.inbox_id, **params)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetDraftBlock(Block):
|
||||
"""
|
||||
Retrieve a specific draft from an AgentMail inbox.
|
||||
|
||||
Returns the draft contents including recipients, subject, body, and
|
||||
scheduled send status. Use this to review a draft before approving it.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(description="Unique identifier of the draft")
|
||||
subject: str = SchemaField(description="Draft subject line", default="")
|
||||
send_status: str = SchemaField(
|
||||
description="Scheduled send status: 'scheduled', 'sending', 'failed', or empty",
|
||||
default="",
|
||||
)
|
||||
send_at: str = SchemaField(
|
||||
description="Scheduled send time (ISO 8601) if set", default=""
|
||||
)
|
||||
result: dict = SchemaField(description="Complete draft object with all fields")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8e57780d-dc25-43d4-a0f4-1f02877b09fb",
|
||||
description="Retrieve a draft email to review its contents, recipients, and scheduled send status.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "test-draft"),
|
||||
("subject", ""),
|
||||
("send_status", ""),
|
||||
("send_at", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "test-draft",
|
||||
"subject": "",
|
||||
"send_status": "",
|
||||
"send_at": "",
|
||||
"model_dump": lambda self: {"draft_id": "test-draft"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.get(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
draft = await self.get_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "subject", draft.subject or ""
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "send_at", draft.send_at or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts in an AgentMail inbox with optional label filtering.
|
||||
|
||||
Use labels=['scheduled'] to find all drafts queued for future sending.
|
||||
Useful for building approval dashboards or monitoring pending outreach.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list drafts from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Filter drafts by labels (e.g. ['scheduled'] for pending sends)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects with subject, recipients, send_status, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e84883b7-7c39-4c5c-88e8-0a72b078ea63",
|
||||
description="List drafts in an AgentMail inbox. Filter by labels=['scheduled'] to find pending sends.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_drafts(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.list(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_drafts(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateDraftBlock(Block):
|
||||
"""
|
||||
Update an existing draft's content, recipients, or scheduled send time.
|
||||
|
||||
Use this to reschedule a draft (change send_at), modify recipients,
|
||||
or edit the subject/body before sending. To cancel a scheduled send,
|
||||
delete the draft instead.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to update")
|
||||
to: Optional[list[str]] = SchemaField(
|
||||
description="Updated recipient email addresses (replaces existing list). Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
subject: Optional[str] = SchemaField(
|
||||
description="Updated subject line. Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
text: Optional[str] = SchemaField(
|
||||
description="Updated plain text body. Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
html: Optional[str] = SchemaField(
|
||||
description="Updated HTML body. Omit to keep current value.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
send_at: Optional[str] = SchemaField(
|
||||
description="Reschedule: new ISO 8601 send time (e.g. '2025-01-20T14:00:00Z'). Omit to keep current value.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(description="The updated draft ID")
|
||||
send_status: str = SchemaField(description="Updated send status", default="")
|
||||
result: dict = SchemaField(description="Complete updated draft object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="351f6e51-695a-421a-9032-46a587b10336",
|
||||
description="Update a draft's content, recipients, or scheduled send time. Use to reschedule or edit before sending.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "test-draft"),
|
||||
("send_status", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "test-draft",
|
||||
"send_status": "",
|
||||
"model_dump": lambda self: {"draft_id": "test-draft"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_draft(
|
||||
credentials: APIKeyCredentials, inbox_id: str, draft_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.update(
|
||||
inbox_id=inbox_id, draft_id=draft_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.to is not None:
|
||||
params["to"] = input_data.to
|
||||
if input_data.subject is not None:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text is not None:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html is not None:
|
||||
params["html"] = input_data.html
|
||||
if input_data.send_at is not None:
|
||||
params["send_at"] = input_data.send_at
|
||||
|
||||
draft = await self.update_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id, **params
|
||||
)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailSendDraftBlock(Block):
|
||||
"""
|
||||
Send a draft immediately, converting it into a delivered message.
|
||||
|
||||
The draft is deleted after successful sending and becomes a regular
|
||||
message with a message_id. Use this for human-in-the-loop approval
|
||||
workflows: agent creates draft, human reviews, then this block sends it.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to send now")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID of the now-sent email (draft is deleted)"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID the sent message belongs to"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete sent message object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="37c39e83-475d-4b3d-843a-d923d001b85a",
|
||||
description="Send a draft immediately, converting it into a delivered message. The draft is deleted after sending.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-msg-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"send_draft": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-msg-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.send(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
msg = await self.send_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteDraftBlock(Block):
|
||||
"""
|
||||
Delete a draft from an AgentMail inbox. Also cancels any scheduled send.
|
||||
|
||||
If the draft was scheduled with send_at, deleting it cancels the
|
||||
scheduled delivery. This is the way to cancel a scheduled email.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(
|
||||
description="Draft ID to delete (also cancels scheduled sends)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the draft was successfully deleted/cancelled"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9023eb99-3e2f-4def-808b-d9c584b3d9e7",
|
||||
description="Delete a draft or cancel a scheduled email. Removes the draft permanently.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_draft": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_draft(
|
||||
credentials: APIKeyCredentials, inbox_id: str, draft_id: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.drafts.delete(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListOrgDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts across every inbox in your organization.
|
||||
|
||||
Returns drafts from all inboxes in one query. Perfect for building
|
||||
a central approval dashboard where a human supervisor can review
|
||||
and approve any draft created by any agent.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects from all inboxes in the organization"
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ed7558ae-3a07-45f5-af55-a25fe88c9971",
|
||||
description="List all drafts across every inbox in your organization. Use for central approval dashboards.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_org_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_org_drafts(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.drafts.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_org_drafts(credentials, **params)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
414
autogpt_platform/backend/backend/blocks/agent_mail/inbox.py
Normal file
414
autogpt_platform/backend/backend/blocks/agent_mail/inbox.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
AgentMail Inbox blocks — create, get, list, update, and delete inboxes.
|
||||
|
||||
An Inbox is a fully programmable email account for AI agents. Each inbox gets
|
||||
a unique email address and can send, receive, and manage emails via the
|
||||
AgentMail API. You can create thousands of inboxes on demand.
|
||||
"""
|
||||
|
||||
from agentmail.inboxes.types import CreateInboxRequest
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreateInboxBlock(Block):
|
||||
"""
|
||||
Create a new email inbox for an AI agent via AgentMail.
|
||||
|
||||
Each inbox gets a unique email address (e.g. username@agentmail.to).
|
||||
If username and domain are not provided, AgentMail auto-generates them.
|
||||
Use custom domains by specifying the domain field.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
username: str = SchemaField(
|
||||
description="Local part of the email address (e.g. 'support' for support@domain.com). Leave empty to auto-generate.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field of sent emails (e.g. 'Support Agent')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(
|
||||
description="Unique identifier for the created inbox (also the email address)"
|
||||
)
|
||||
email_address: str = SchemaField(
|
||||
description="Full email address of the inbox (e.g. support@agentmail.to)"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7a8ac219-c6ec-4eec-a828-81af283ce04c",
|
||||
description="Create a new email inbox for an AI agent via AgentMail. Each inbox gets a unique address and can send/receive emails.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("inbox_id", "mock-inbox-id"),
|
||||
("email_address", "mock-inbox-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "mock-inbox-id",
|
||||
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_inbox(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.create(request=CreateInboxRequest(**params))
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.username:
|
||||
params["username"] = input_data.username
|
||||
if input_data.domain:
|
||||
params["domain"] = input_data.domain
|
||||
if input_data.display_name:
|
||||
params["display_name"] = input_data.display_name
|
||||
|
||||
inbox = await self.create_inbox(credentials, **params)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetInboxBlock(Block):
|
||||
"""
|
||||
Retrieve details of an existing AgentMail inbox by its ID or email address.
|
||||
|
||||
Returns the inbox metadata including email address, display name, and
|
||||
configuration. Use this to check if an inbox exists or get its properties.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to look up (e.g. 'support@agentmail.to')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(description="Unique identifier of the inbox")
|
||||
email_address: str = SchemaField(description="Full email address of the inbox")
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field", default=""
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b858f62b-6c12-4736-aaf2-dbc5a9281320",
|
||||
description="Retrieve details of an existing AgentMail inbox including its email address, display name, and configuration.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("inbox_id", "test-inbox"),
|
||||
("email_address", "test-inbox"),
|
||||
("display_name", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "test-inbox",
|
||||
"display_name": "",
|
||||
"model_dump": lambda self: {"inbox_id": "test-inbox"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_inbox(credentials: APIKeyCredentials, inbox_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.get(inbox_id=inbox_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
inbox = await self.get_inbox(credentials, input_data.inbox_id)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "display_name", inbox.display_name or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListInboxesBlock(Block):
|
||||
"""
|
||||
List all email inboxes in your AgentMail organization.
|
||||
|
||||
Returns a paginated list of all inboxes with their metadata.
|
||||
Use page_token for pagination when you have many inboxes.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of inboxes to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page of results",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inboxes: list[dict] = SchemaField(
|
||||
description="List of inbox objects, each containing inbox_id, email_address, display_name, etc."
|
||||
)
|
||||
count: int = SchemaField(
|
||||
description="Total number of inboxes in your organization"
|
||||
)
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token to pass as page_token to get the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cfd84a06-2121-4cef-8d14-8badf52d22f0",
|
||||
description="List all email inboxes in your AgentMail organization with pagination support.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("inboxes", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_inboxes": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"inboxes": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_inboxes(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_inboxes(credentials, **params)
|
||||
inboxes = [i.model_dump() for i in response.inboxes]
|
||||
|
||||
yield "inboxes", inboxes
|
||||
yield "count", (c if (c := response.count) is not None else len(inboxes))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateInboxBlock(Block):
|
||||
"""
|
||||
Update the display name of an existing AgentMail inbox.
|
||||
|
||||
Changes the friendly name shown in the 'From' field when emails are sent
|
||||
from this inbox. The email address itself cannot be changed.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to update (e.g. 'support@agentmail.to')"
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="New display name for the inbox (e.g. 'Customer Support Bot')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(description="The updated inbox ID")
|
||||
result: dict = SchemaField(
|
||||
description="Complete updated inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="59b49f59-a6d1-4203-94c0-3908adac50b6",
|
||||
description="Update the display name of an AgentMail inbox. Changes the 'From' name shown when emails are sent.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"display_name": "Updated",
|
||||
},
|
||||
test_output=[
|
||||
("inbox_id", "test-inbox"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "test-inbox",
|
||||
"model_dump": lambda self: {"inbox_id": "test-inbox"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_inbox(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.update(inbox_id=inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
inbox = await self.update_inbox(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
display_name=input_data.display_name,
|
||||
)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteInboxBlock(Block):
|
||||
"""
|
||||
Permanently delete an AgentMail inbox and all its data.
|
||||
|
||||
This removes the inbox, all its messages, threads, and drafts.
|
||||
This action cannot be undone. The email address will no longer
|
||||
receive or send emails.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to permanently delete"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the inbox was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ade970ae-8428-4a7b-9278-b52054dbf535",
|
||||
description="Permanently delete an AgentMail inbox and all its messages, threads, and drafts. This action cannot be undone.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_inbox": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_inbox(credentials: APIKeyCredentials, inbox_id: str):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.delete(inbox_id=inbox_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_inbox(credentials, input_data.inbox_id)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
384
autogpt_platform/backend/backend/blocks/agent_mail/lists.py
Normal file
384
autogpt_platform/backend/backend/blocks/agent_mail/lists.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
AgentMail List blocks — manage allow/block lists for email filtering.
|
||||
|
||||
Lists let you control which email addresses and domains your agents can
|
||||
send to or receive from. There are four list types based on two dimensions:
|
||||
direction (send/receive) and type (allow/block).
|
||||
|
||||
- receive + allow: Only accept emails from these addresses/domains
|
||||
- receive + block: Reject emails from these addresses/domains
|
||||
- send + allow: Only send emails to these addresses/domains
|
||||
- send + block: Prevent sending emails to these addresses/domains
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class ListDirection(str, Enum):
|
||||
SEND = "send"
|
||||
RECEIVE = "receive"
|
||||
|
||||
|
||||
class ListType(str, Enum):
|
||||
ALLOW = "allow"
|
||||
BLOCK = "block"
|
||||
|
||||
|
||||
class AgentMailListEntriesBlock(Block):
|
||||
"""
|
||||
List all entries in an AgentMail allow/block list.
|
||||
|
||||
Retrieves email addresses and domains that are currently allowed
|
||||
or blocked for sending or receiving. Use direction and list_type
|
||||
to select which of the four lists to query.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' to filter outgoing emails, 'receive' to filter incoming emails"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist (only permit these), 'block' for blacklist (reject these)"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of entries to return per page",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entries: list[dict] = SchemaField(
|
||||
description="List of entries, each with an email address or domain"
|
||||
)
|
||||
count: int = SchemaField(description="Number of entries returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="01489100-35da-45aa-8a01-9540ba0e9a21",
|
||||
description="List all entries in an AgentMail allow/block list. Choose send/receive direction and allow/block type.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
},
|
||||
test_output=[
|
||||
("entries", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_entries": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"entries": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_entries(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.list(direction, list_type, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_entries(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
**params,
|
||||
)
|
||||
entries = [e.model_dump() for e in response.entries]
|
||||
|
||||
yield "entries", entries
|
||||
yield "count", (c if (c := response.count) is not None else len(entries))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailCreateListEntryBlock(Block):
|
||||
"""
|
||||
Add an email address or domain to an AgentMail allow/block list.
|
||||
|
||||
Entries can be full email addresses (e.g. 'partner@example.com') or
|
||||
entire domains (e.g. 'example.com'). For block lists, you can optionally
|
||||
provide a reason (e.g. 'spam', 'competitor').
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing email rules, 'receive' for incoming email rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' to whitelist, 'block' to blacklist"
|
||||
)
|
||||
entry: str = SchemaField(
|
||||
description="Email address (user@example.com) or domain (example.com) to add"
|
||||
)
|
||||
reason: str = SchemaField(
|
||||
description="Reason for blocking (only used with block lists, e.g. 'spam', 'competitor')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entry: str = SchemaField(
|
||||
description="The email address or domain that was added"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete entry object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b6650a0a-b113-40cf-8243-ff20f684f9b8",
|
||||
description="Add an email address or domain to an allow/block list. Block spam senders or whitelist trusted domains.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[
|
||||
("entry", "spam@example.com"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_entry": lambda *a, **kw: type(
|
||||
"Entry",
|
||||
(),
|
||||
{
|
||||
"model_dump": lambda self: {"entry": "spam@example.com"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.create(direction, list_type, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"entry": input_data.entry}
|
||||
if input_data.reason and input_data.list_type == ListType.BLOCK:
|
||||
params["reason"] = input_data.reason
|
||||
|
||||
result = await self.create_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
**params,
|
||||
)
|
||||
result_dict = result.model_dump()
|
||||
|
||||
yield "entry", input_data.entry
|
||||
yield "result", result_dict
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetListEntryBlock(Block):
|
||||
"""
|
||||
Check if an email address or domain exists in an AgentMail allow/block list.
|
||||
|
||||
Returns the entry details if found. Use this to verify whether a specific
|
||||
address or domain is currently allowed or blocked.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing rules, 'receive' for incoming rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist, 'block' for blacklist"
|
||||
)
|
||||
entry: str = SchemaField(description="Email address or domain to look up")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entry: str = SchemaField(
|
||||
description="The email address or domain that was found"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete entry object with metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fb117058-ab27-40d1-9231-eb1dd526fc7a",
|
||||
description="Check if an email address or domain is in an allow/block list. Verify filtering rules.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[
|
||||
("entry", "spam@example.com"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_entry": lambda *a, **kw: type(
|
||||
"Entry",
|
||||
(),
|
||||
{
|
||||
"model_dump": lambda self: {"entry": "spam@example.com"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.get(direction, list_type, entry=entry)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.get_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
input_data.entry,
|
||||
)
|
||||
result_dict = result.model_dump()
|
||||
|
||||
yield "entry", input_data.entry
|
||||
yield "result", result_dict
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteListEntryBlock(Block):
|
||||
"""
|
||||
Remove an email address or domain from an AgentMail allow/block list.
|
||||
|
||||
After removal, the address/domain will no longer be filtered by this list.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing rules, 'receive' for incoming rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist, 'block' for blacklist"
|
||||
)
|
||||
entry: str = SchemaField(
|
||||
description="Email address or domain to remove from the list"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the entry was successfully removed"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2b8d57f1-1c9e-470f-a70b-5991c80fad5f",
|
||||
description="Remove an email address or domain from an allow/block list to stop filtering it.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_entry": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.lists.delete(direction, list_type, entry=entry)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
input_data.entry,
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
695
autogpt_platform/backend/backend/blocks/agent_mail/messages.py
Normal file
695
autogpt_platform/backend/backend/blocks/agent_mail/messages.py
Normal file
@@ -0,0 +1,695 @@
|
||||
"""
|
||||
AgentMail Message blocks — send, list, get, reply, forward, and update messages.
|
||||
|
||||
A Message is an individual email within a Thread. Agents can send new messages
|
||||
(which create threads), reply to existing messages, forward them, and manage
|
||||
labels for state tracking (e.g. read/unread, campaign tags).
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailSendMessageBlock(Block):
|
||||
"""
|
||||
Send a new email from an AgentMail inbox, automatically creating a new thread.
|
||||
|
||||
Supports plain text and HTML bodies, CC/BCC recipients, and labels for
|
||||
organizing messages (e.g. campaign tracking, state management).
|
||||
Max 50 combined recipients across to, cc, and bcc.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to send from (e.g. 'agent@agentmail.to')"
|
||||
)
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses (e.g. ['user@example.com'])"
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject line")
|
||||
text: str = SchemaField(
|
||||
description="Plain text body of the email. Always provide this as a fallback for email clients that don't render HTML."
|
||||
)
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the email. Embed CSS in a <style> tag for best compatibility across email clients.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses for human-in-the-loop oversight",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses (hidden from other recipients)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Labels to tag the message for filtering and state management (e.g. ['outreach', 'q4-campaign'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the sent message"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID grouping this message and any future replies"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete sent message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b67469b2-7748-4d81-a223-4ebd332cca89",
|
||||
description="Send a new email from an AgentMail inbox. Creates a new conversation thread. Supports HTML, CC/BCC, and labels.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"to": ["user@example.com"],
|
||||
"subject": "Test",
|
||||
"text": "Hello",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-msg-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"send_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_message(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.send(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
|
||||
if total > 50:
|
||||
raise ValueError(
|
||||
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"to": input_data.to,
|
||||
"subject": input_data.subject,
|
||||
"text": input_data.text,
|
||||
}
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
msg = await self.send_message(credentials, input_data.inbox_id, **params)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListMessagesBlock(Block):
|
||||
"""
|
||||
List all messages in an AgentMail inbox with optional label filtering.
|
||||
|
||||
Returns a paginated list of messages. Use labels to filter (e.g.
|
||||
labels=['unread'] to only get unprocessed messages). Useful for
|
||||
polling workflows or building inbox views.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list messages from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of messages to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return messages with ALL of these labels (e.g. ['unread'] or ['q4-campaign', 'follow-up'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
messages: list[dict] = SchemaField(
|
||||
description="List of message objects with subject, sender, text, html, labels, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of messages returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="721234df-c7a2-4927-b205-744badbd5844",
|
||||
description="List messages in an AgentMail inbox. Filter by labels to find unread, campaign-tagged, or categorized messages.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("messages", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_messages": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"messages": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_messages(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.list(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_messages(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
messages = [m.model_dump() for m in response.messages]
|
||||
|
||||
yield "messages", messages
|
||||
yield "count", (c if (c := response.count) is not None else len(messages))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetMessageBlock(Block):
|
||||
"""
|
||||
Retrieve a specific email message by ID from an AgentMail inbox.
|
||||
|
||||
Returns the full message including subject, body (text and HTML),
|
||||
sender, recipients, and attachments. Use extracted_text to get
|
||||
only the new reply content without quoted history.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to retrieve (e.g. '<abc123@agentmail.to>')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(description="Unique identifier of the message")
|
||||
thread_id: str = SchemaField(description="Thread this message belongs to")
|
||||
subject: str = SchemaField(description="Email subject line")
|
||||
text: str = SchemaField(
|
||||
description="Full plain text body (may include quoted reply history)"
|
||||
)
|
||||
extracted_text: str = SchemaField(
|
||||
description="Just the new reply content with quoted history stripped. Best for AI processing.",
|
||||
default="",
|
||||
)
|
||||
html: str = SchemaField(description="HTML body of the email", default="")
|
||||
result: dict = SchemaField(
|
||||
description="Complete message object with all fields including sender, recipients, attachments, labels"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2788bdfa-1527-4603-a5e4-a455c05c032f",
|
||||
description="Retrieve a specific email message by ID. Includes extracted_text for clean reply content without quoted history.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "test-msg"),
|
||||
("thread_id", "t1"),
|
||||
("subject", "Hi"),
|
||||
("text", "Hello"),
|
||||
("extracted_text", "Hello"),
|
||||
("html", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "test-msg",
|
||||
"thread_id": "t1",
|
||||
"subject": "Hi",
|
||||
"text": "Hello",
|
||||
"extracted_text": "Hello",
|
||||
"html": "",
|
||||
"model_dump": lambda self: {"message_id": "test-msg"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_message(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
message_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.get(
|
||||
inbox_id=inbox_id, message_id=message_id
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
msg = await self.get_message(
|
||||
credentials, input_data.inbox_id, input_data.message_id
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "subject", msg.subject or ""
|
||||
yield "text", msg.text or ""
|
||||
yield "extracted_text", msg.extracted_text or ""
|
||||
yield "html", msg.html or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailReplyToMessageBlock(Block):
|
||||
"""
|
||||
Reply to an existing email message, keeping the reply in the same thread.
|
||||
|
||||
The reply is automatically added to the same conversation thread as the
|
||||
original message. Use this for multi-turn agent conversations.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to send the reply from"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to reply to (e.g. '<abc123@agentmail.to>')"
|
||||
)
|
||||
text: str = SchemaField(description="Plain text body of the reply")
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the reply",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the reply message"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID the reply was added to")
|
||||
result: dict = SchemaField(
|
||||
description="Complete reply message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b9fe53fa-5026-4547-9570-b54ccb487229",
|
||||
description="Reply to an existing email in the same conversation thread. Use for multi-turn agent conversations.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"text": "Reply",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-reply-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"reply_to_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-reply-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-reply-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def reply_to_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.reply(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"text": input_data.text}
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
|
||||
reply = await self.reply_to_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = reply.model_dump()
|
||||
|
||||
yield "message_id", reply.message_id
|
||||
yield "thread_id", reply.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailForwardMessageBlock(Block):
|
||||
"""
|
||||
Forward an existing email message to one or more recipients.
|
||||
|
||||
Sends the original message content to different email addresses.
|
||||
Optionally prepend additional text or override the subject line.
|
||||
Max 50 combined recipients across to, cc, and bcc.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to forward from"
|
||||
)
|
||||
message_id: str = SchemaField(description="Message ID to forward")
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses to forward the message to (e.g. ['user@example.com'])"
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses (hidden from other recipients)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
subject: str = SchemaField(
|
||||
description="Override the subject line (defaults to 'Fwd: <original subject>')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text: str = SchemaField(
|
||||
description="Additional plain text to prepend before the forwarded content",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
html: str = SchemaField(
|
||||
description="Additional HTML to prepend before the forwarded content",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the forwarded message"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID of the forward")
|
||||
result: dict = SchemaField(
|
||||
description="Complete forwarded message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b70c7e33-5d66-4f8e-897f-ac73a7bfce82",
|
||||
description="Forward an email message to one or more recipients. Supports CC/BCC and optional extra text or subject override.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"to": ["user@example.com"],
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-fwd-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"forward_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-fwd-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-fwd-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def forward_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.forward(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
|
||||
if total > 50:
|
||||
raise ValueError(
|
||||
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
|
||||
)
|
||||
|
||||
params: dict = {"to": input_data.to}
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.subject:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
|
||||
fwd = await self.forward_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = fwd.model_dump()
|
||||
|
||||
yield "message_id", fwd.message_id
|
||||
yield "thread_id", fwd.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateMessageBlock(Block):
|
||||
"""
|
||||
Add or remove labels on an email message for state management.
|
||||
|
||||
Labels are string tags used to track message state (read/unread),
|
||||
categorize messages (billing, support), or tag campaigns (q4-outreach).
|
||||
Common pattern: add 'read' and remove 'unread' after processing a message.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(description="Message ID to update labels on")
|
||||
add_labels: list[str] = SchemaField(
|
||||
description="Labels to add (e.g. ['read', 'processed', 'high-priority'])",
|
||||
default_factory=list,
|
||||
)
|
||||
remove_labels: list[str] = SchemaField(
|
||||
description="Labels to remove (e.g. ['unread', 'pending'])",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(description="The updated message ID")
|
||||
result: dict = SchemaField(
|
||||
description="Complete updated message object with current labels"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="694ff816-4c89-4a5e-a552-8c31be187735",
|
||||
description="Add or remove labels on an email message. Use for read/unread tracking, campaign tagging, or state management.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"add_labels": ["read"],
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "test-msg"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "test-msg",
|
||||
"model_dump": lambda self: {"message_id": "test-msg"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.update(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
if not input_data.add_labels and not input_data.remove_labels:
|
||||
raise ValueError(
|
||||
"Must specify at least one label operation: add_labels or remove_labels"
|
||||
)
|
||||
|
||||
params: dict = {}
|
||||
if input_data.add_labels:
|
||||
params["add_labels"] = input_data.add_labels
|
||||
if input_data.remove_labels:
|
||||
params["remove_labels"] = input_data.remove_labels
|
||||
|
||||
msg = await self.update_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
651
autogpt_platform/backend/backend/blocks/agent_mail/pods.py
Normal file
651
autogpt_platform/backend/backend/blocks/agent_mail/pods.py
Normal file
@@ -0,0 +1,651 @@
|
||||
"""
|
||||
AgentMail Pod blocks — create, get, list, delete pods and list pod-scoped resources.
|
||||
|
||||
Pods provide multi-tenant isolation between your customers. Each pod acts as
|
||||
an isolated workspace containing its own inboxes, domains, threads, and drafts.
|
||||
Use pods when building SaaS platforms, agency tools, or AI agent fleets that
|
||||
serve multiple customers.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreatePodBlock(Block):
|
||||
"""
|
||||
Create a new pod for multi-tenant customer isolation.
|
||||
|
||||
Each pod acts as an isolated workspace for one customer or tenant.
|
||||
Use client_id to map pods to your internal tenant IDs for idempotent
|
||||
creation (safe to retry without creating duplicates).
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
client_id: str = SchemaField(
|
||||
description="Your internal tenant/customer ID for idempotent mapping. Lets you access the pod by your own ID instead of AgentMail's pod_id.",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pod_id: str = SchemaField(description="Unique identifier of the created pod")
|
||||
result: dict = SchemaField(description="Complete pod object with all metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a2db9784-2d17-4f8f-9d6b-0214e6f22101",
|
||||
description="Create a new pod for multi-tenant customer isolation. Use client_id to map to your internal tenant IDs.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("pod_id", "mock-pod-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_pod": lambda *a, **kw: type(
|
||||
"Pod",
|
||||
(),
|
||||
{
|
||||
"pod_id": "mock-pod-id",
|
||||
"model_dump": lambda self: {"pod_id": "mock-pod-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_pod(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.create(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.client_id:
|
||||
params["client_id"] = input_data.client_id
|
||||
|
||||
pod = await self.create_pod(credentials, **params)
|
||||
result = pod.model_dump()
|
||||
|
||||
yield "pod_id", pod.pod_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetPodBlock(Block):
|
||||
"""
|
||||
Retrieve details of an existing pod by its ID.
|
||||
|
||||
Returns the pod metadata including its client_id mapping and
|
||||
creation timestamp.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pod_id: str = SchemaField(description="Unique identifier of the pod")
|
||||
result: dict = SchemaField(description="Complete pod object with all metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="553361bc-bb1b-4322-9ad4-0c226200217e",
|
||||
description="Retrieve details of an existing pod including its client_id mapping and metadata.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("pod_id", "test-pod"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_pod": lambda *a, **kw: type(
|
||||
"Pod",
|
||||
(),
|
||||
{
|
||||
"pod_id": "test-pod",
|
||||
"model_dump": lambda self: {"pod_id": "test-pod"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_pod(credentials: APIKeyCredentials, pod_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.pods.get(pod_id=pod_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
pod = await self.get_pod(credentials, pod_id=input_data.pod_id)
|
||||
result = pod.model_dump()
|
||||
|
||||
yield "pod_id", pod.pod_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodsBlock(Block):
|
||||
"""
|
||||
List all pods in your AgentMail organization.
|
||||
|
||||
Returns a paginated list of all tenant pods with their metadata.
|
||||
Use this to see all customer workspaces at a glance.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of pods to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pods: list[dict] = SchemaField(
|
||||
description="List of pod objects with pod_id, client_id, creation time, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of pods returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9d3725ee-2968-431a-a816-857ab41e1420",
|
||||
description="List all tenant pods in your organization. See all customer workspaces at a glance.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("pods", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pods": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"pods": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pods(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pods(credentials, **params)
|
||||
pods = [p.model_dump() for p in response.pods]
|
||||
|
||||
yield "pods", pods
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeletePodBlock(Block):
|
||||
"""
|
||||
Permanently delete a pod. All inboxes and domains must be removed first.
|
||||
|
||||
You cannot delete a pod that still contains inboxes or domains.
|
||||
Delete all child resources first, then delete the pod.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(
|
||||
description="Pod ID to permanently delete (must have no inboxes or domains)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the pod was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f371f8cd-682d-4f5f-905c-529c74a8fb35",
|
||||
description="Permanently delete a pod. All inboxes and domains must be removed first.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_pod": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_pod(credentials: APIKeyCredentials, pod_id: str):
|
||||
client = _client(credentials)
|
||||
await client.pods.delete(pod_id=pod_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_pod(credentials, pod_id=input_data.pod_id)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodInboxesBlock(Block):
|
||||
"""
|
||||
List all inboxes within a specific pod (customer workspace).
|
||||
|
||||
Returns only the inboxes belonging to this pod, providing
|
||||
tenant-scoped visibility.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list inboxes from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of inboxes to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inboxes: list[dict] = SchemaField(
|
||||
description="List of inbox objects within this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of inboxes returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a8c17ce0-b7c1-4bc3-ae39-680e1952e5d0",
|
||||
description="List all inboxes within a pod. View email accounts scoped to a specific customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("inboxes", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_inboxes": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"inboxes": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_inboxes(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.inboxes.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pod_inboxes(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
inboxes = [i.model_dump() for i in response.inboxes]
|
||||
|
||||
yield "inboxes", inboxes
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodThreadsBlock(Block):
|
||||
"""
|
||||
List all conversation threads across all inboxes within a pod.
|
||||
|
||||
Returns threads from every inbox in the pod. Use for building
|
||||
per-customer dashboards showing all email activity, or for
|
||||
supervisor agents monitoring a customer's conversations.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list threads from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects from all inboxes in this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="80214f08-8b85-4533-a6b8-f8123bfcb410",
|
||||
description="List all conversation threads across all inboxes within a pod. View all email activity for a customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_threads(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.threads.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_pod_threads(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts across all inboxes within a pod.
|
||||
|
||||
Returns pending drafts from every inbox in the pod. Use for
|
||||
per-customer approval dashboards or monitoring scheduled sends.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list drafts from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects from all inboxes in this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="12fd7a3e-51ad-4b20-97c1-0391f207f517",
|
||||
description="List all drafts across all inboxes within a pod. View pending emails for a customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_drafts(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.drafts.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pod_drafts(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailCreatePodInboxBlock(Block):
|
||||
"""
|
||||
Create a new email inbox within a specific pod (customer workspace).
|
||||
|
||||
The inbox is automatically scoped to the pod and inherits its
|
||||
isolation guarantees. If username/domain are not provided,
|
||||
AgentMail auto-generates a unique address.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to create the inbox in")
|
||||
username: str = SchemaField(
|
||||
description="Local part of the email address (e.g. 'support'). Leave empty to auto-generate.",
|
||||
default="",
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
|
||||
default="",
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field (e.g. 'Customer Support')",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(
|
||||
description="Unique identifier of the created inbox"
|
||||
)
|
||||
email_address: str = SchemaField(description="Full email address of the inbox")
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c6862373-1ac6-402e-89e6-7db1fea882af",
|
||||
description="Create a new email inbox within a pod. The inbox is scoped to the customer workspace.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("inbox_id", "mock-inbox-id"),
|
||||
("email_address", "mock-inbox-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_pod_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "mock-inbox-id",
|
||||
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_pod_inbox(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.inboxes.create(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.username:
|
||||
params["username"] = input_data.username
|
||||
if input_data.domain:
|
||||
params["domain"] = input_data.domain
|
||||
if input_data.display_name:
|
||||
params["display_name"] = input_data.display_name
|
||||
|
||||
inbox = await self.create_pod_inbox(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
438
autogpt_platform/backend/backend/blocks/agent_mail/threads.py
Normal file
438
autogpt_platform/backend/backend/blocks/agent_mail/threads.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
AgentMail Thread blocks — list, get, and delete conversation threads.
|
||||
|
||||
A Thread groups related messages into a single conversation. Threads are
|
||||
created automatically when a new message is sent and grow as replies are added.
|
||||
Threads can be queried per-inbox or across the entire organization.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailListInboxThreadsBlock(Block):
|
||||
"""
|
||||
List all conversation threads within a specific AgentMail inbox.
|
||||
|
||||
Returns a paginated list of threads with optional label filtering.
|
||||
Use labels to find threads by campaign, status, or custom tags.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list threads from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels (e.g. ['q4-campaign', 'follow-up'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects with thread_id, subject, message count, labels, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="63dd9e2d-ef81-405c-b034-c031f0437334",
|
||||
description="List all conversation threads in an AgentMail inbox. Filter by labels for campaign tracking or status management.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_threads(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.list(inbox_id=inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_threads(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", (c if (c := response.count) is not None else len(threads))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetInboxThreadBlock(Block):
|
||||
"""
|
||||
Retrieve a single conversation thread from an AgentMail inbox.
|
||||
|
||||
Returns the thread with all its messages in chronological order.
|
||||
Use this to get the full conversation history for context when
|
||||
composing replies.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
thread_id: str = SchemaField(description="Unique identifier of the thread")
|
||||
messages: list[dict] = SchemaField(
|
||||
description="All messages in the thread, in chronological order"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete thread object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="42866290-1479-4153-83e7-550b703e9da2",
|
||||
description="Retrieve a conversation thread with all its messages. Use for getting full conversation context before replying.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[
|
||||
("thread_id", "test-thread"),
|
||||
("messages", []),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_thread": lambda *a, **kw: type(
|
||||
"Thread",
|
||||
(),
|
||||
{
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
"model_dump": lambda self: {
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_thread(credentials: APIKeyCredentials, inbox_id: str, thread_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.get(inbox_id=inbox_id, thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
thread = await self.get_thread(
|
||||
credentials, input_data.inbox_id, input_data.thread_id
|
||||
)
|
||||
messages = [m.model_dump() for m in thread.messages]
|
||||
result = thread.model_dump()
|
||||
result["messages"] = messages
|
||||
|
||||
yield "thread_id", thread.thread_id
|
||||
yield "messages", messages
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteInboxThreadBlock(Block):
|
||||
"""
|
||||
Permanently delete a conversation thread and all its messages from an inbox.
|
||||
|
||||
This removes the thread and every message within it. This action
|
||||
cannot be undone.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID to permanently delete")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the thread was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="18cd5f6f-4ff6-45da-8300-25a50ea7fb75",
|
||||
description="Permanently delete a conversation thread and all its messages. This action cannot be undone.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_thread": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_thread(
|
||||
credentials: APIKeyCredentials, inbox_id: str, thread_id: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.threads.delete(inbox_id=inbox_id, thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_thread(
|
||||
credentials, input_data.inbox_id, input_data.thread_id
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListOrgThreadsBlock(Block):
|
||||
"""
|
||||
List conversation threads across ALL inboxes in your organization.
|
||||
|
||||
Unlike per-inbox listing, this returns threads from every inbox.
|
||||
Ideal for building supervisor agents that monitor all conversations,
|
||||
analytics dashboards, or cross-agent routing workflows.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects from all inboxes in the organization"
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d7a0657b-58ab-48b2-898b-7bd94f44a708",
|
||||
description="List threads across ALL inboxes in your organization. Use for supervisor agents, dashboards, or cross-agent monitoring.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_org_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_org_threads(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.threads.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_org_threads(credentials, **params)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", (c if (c := response.count) is not None else len(threads))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetOrgThreadBlock(Block):
|
||||
"""
|
||||
Retrieve a single conversation thread by ID from anywhere in the organization.
|
||||
|
||||
Works without needing to know which inbox the thread belongs to.
|
||||
Returns the thread with all its messages in chronological order.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID to retrieve (works across all inboxes)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
thread_id: str = SchemaField(description="Unique identifier of the thread")
|
||||
messages: list[dict] = SchemaField(
|
||||
description="All messages in the thread, in chronological order"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete thread object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="39aaae31-3eb1-44c6-9e37-5a44a4529649",
|
||||
description="Retrieve a conversation thread by ID from anywhere in the organization, without needing the inbox ID.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[
|
||||
("thread_id", "test-thread"),
|
||||
("messages", []),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_org_thread": lambda *a, **kw: type(
|
||||
"Thread",
|
||||
(),
|
||||
{
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
"model_dump": lambda self: {
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_org_thread(credentials: APIKeyCredentials, thread_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.threads.get(thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
thread = await self.get_org_thread(credentials, input_data.thread_id)
|
||||
messages = [m.model_dump() for m in thread.messages]
|
||||
result = thread.model_dump()
|
||||
result["messages"] = messages
|
||||
|
||||
yield "thread_id", thread.thread_id
|
||||
yield "messages", messages
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -27,6 +27,7 @@ from backend.util.file import MediaFileType, store_media_file
|
||||
class GeminiImageModel(str, Enum):
|
||||
NANO_BANANA = "google/nano-banana"
|
||||
NANO_BANANA_PRO = "google/nano-banana-pro"
|
||||
NANO_BANANA_2 = "google/nano-banana-2"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -77,7 +78,7 @@ class AIImageCustomizerBlock(Block):
|
||||
)
|
||||
model: GeminiImageModel = SchemaField(
|
||||
description="The AI model to use for image generation and editing",
|
||||
default=GeminiImageModel.NANO_BANANA,
|
||||
default=GeminiImageModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
images: list[MediaFileType] = SchemaField(
|
||||
@@ -103,7 +104,7 @@ class AIImageCustomizerBlock(Block):
|
||||
super().__init__(
|
||||
id="d76bbe4c-930e-4894-8469-b66775511f71",
|
||||
description=(
|
||||
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
|
||||
"Generate and edit custom images using Google's Nano-Banana models from Gemini. "
|
||||
"Provide a prompt and optional reference images to create or modify images."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
@@ -111,7 +112,7 @@ class AIImageCustomizerBlock(Block):
|
||||
output_schema=AIImageCustomizerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Make the scene more vibrant and colorful",
|
||||
"model": GeminiImageModel.NANO_BANANA,
|
||||
"model": GeminiImageModel.NANO_BANANA_2,
|
||||
"images": [],
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"output_format": OutputFormat.JPG,
|
||||
|
||||
@@ -115,6 +115,7 @@ class ImageGenModel(str, Enum):
|
||||
RECRAFT = "Recraft v3"
|
||||
SD3_5 = "Stable Diffusion 3.5 Medium"
|
||||
NANO_BANANA_PRO = "Nano Banana Pro"
|
||||
NANO_BANANA_2 = "Nano Banana 2"
|
||||
|
||||
|
||||
class AIImageGeneratorBlock(Block):
|
||||
@@ -131,7 +132,7 @@ class AIImageGeneratorBlock(Block):
|
||||
)
|
||||
model: ImageGenModel = SchemaField(
|
||||
description="The AI model to use for image generation",
|
||||
default=ImageGenModel.SD3_5,
|
||||
default=ImageGenModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
size: ImageSize = SchemaField(
|
||||
@@ -165,7 +166,7 @@ class AIImageGeneratorBlock(Block):
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"prompt": "An octopus using a laptop in a snowy forest with 'AutoGPT' clearly visible on the screen",
|
||||
"model": ImageGenModel.RECRAFT,
|
||||
"model": ImageGenModel.NANO_BANANA_2,
|
||||
"size": ImageSize.SQUARE,
|
||||
"style": ImageStyle.REALISTIC,
|
||||
},
|
||||
@@ -179,7 +180,9 @@ class AIImageGeneratorBlock(Block):
|
||||
],
|
||||
test_mock={
|
||||
# Return a data URI directly so store_media_file doesn't need to download
|
||||
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
|
||||
"_run_client": lambda *args, **kwargs: (
|
||||
"data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -280,17 +283,24 @@ class AIImageGeneratorBlock(Block):
|
||||
)
|
||||
return output
|
||||
|
||||
elif input_data.model == ImageGenModel.NANO_BANANA_PRO:
|
||||
# Use Nano Banana Pro (Google Gemini 3 Pro Image)
|
||||
elif input_data.model in (
|
||||
ImageGenModel.NANO_BANANA_PRO,
|
||||
ImageGenModel.NANO_BANANA_2,
|
||||
):
|
||||
# Use Nano Banana models (Google Gemini image variants)
|
||||
model_map = {
|
||||
ImageGenModel.NANO_BANANA_PRO: "google/nano-banana-pro",
|
||||
ImageGenModel.NANO_BANANA_2: "google/nano-banana-2",
|
||||
}
|
||||
input_params = {
|
||||
"prompt": modified_prompt,
|
||||
"aspect_ratio": SIZE_TO_NANO_BANANA_RATIO[input_data.size],
|
||||
"resolution": "2K", # Default to 2K for good quality/cost balance
|
||||
"resolution": "2K",
|
||||
"output_format": "jpg",
|
||||
"safety_filter_level": "block_only_high", # Most permissive
|
||||
"safety_filter_level": "block_only_high",
|
||||
}
|
||||
output = await self._run_client(
|
||||
credentials, "google/nano-banana-pro", input_params
|
||||
credentials, model_map[input_data.model], input_params
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
520
autogpt_platform/backend/backend/blocks/autopilot.py
Normal file
520
autogpt_platform/backend/backend/blocks/autopilot.py
Normal file
@@ -0,0 +1,520 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import TypedDict # Needed for Python <3.12 compatibility
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.copilot.permissions import (
|
||||
CopilotPermissions,
|
||||
ToolName,
|
||||
all_known_tool_names,
|
||||
validate_block_identifiers,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Block ID shared between autopilot.py and copilot prompting.py.
|
||||
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
|
||||
|
||||
class ToolCallEntry(TypedDict):
|
||||
"""A single tool invocation record from an autopilot execution."""
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
input: Any
|
||||
output: Any | None
|
||||
success: bool | None
|
||||
|
||||
|
||||
class TokenUsage(TypedDict):
|
||||
"""Aggregated token counts from the autopilot stream."""
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class AutoPilotBlock(Block):
|
||||
"""Execute tasks using AutoGPT AutoPilot with full access to platform tools.
|
||||
|
||||
The autopilot can manage agents, access workspace files, fetch web content,
|
||||
run blocks, and more. This block enables sub-agent patterns (autopilot calling
|
||||
autopilot) and scheduled autopilot execution via the agent executor.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
"""Input schema for the AutoPilot block."""
|
||||
|
||||
prompt: str = SchemaField(
|
||||
description=(
|
||||
"The task or instruction for the autopilot to execute. "
|
||||
"The autopilot has access to platform tools like agent management, "
|
||||
"workspace files, web fetch, block execution, and more."
|
||||
),
|
||||
placeholder="Find my agents and list them",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
system_context: str = SchemaField(
|
||||
description=(
|
||||
"Optional additional context prepended to the prompt. "
|
||||
"Use this to constrain autopilot behavior, provide domain "
|
||||
"context, or set output format requirements."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID to continue an existing autopilot conversation. "
|
||||
"Leave empty to start a new session. "
|
||||
"Use the session_id output from a previous run to continue."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
max_recursion_depth: int = SchemaField(
|
||||
description=(
|
||||
"Maximum nesting depth when the autopilot calls this block "
|
||||
"recursively (sub-agent pattern). Prevents infinite loops."
|
||||
),
|
||||
default=3,
|
||||
ge=1,
|
||||
le=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools: list[ToolName] = SchemaField(
|
||||
description=(
|
||||
"Tool names to filter. Works with tools_exclude to form an "
|
||||
"allow-list or deny-list. "
|
||||
"Leave empty to apply no tool filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'tools' list is interpreted. "
|
||||
"True (default): 'tools' is a deny-list — listed tools are blocked, "
|
||||
"all others are allowed. An empty 'tools' list means allow everything. "
|
||||
"False: 'tools' is an allow-list — only listed tools are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks: list[str] = SchemaField(
|
||||
description=(
|
||||
"Block identifiers to filter when the copilot uses run_block. "
|
||||
"Each entry can be: a block name (e.g. 'HTTP Request'), "
|
||||
"a full block UUID, or the first 8 hex characters of the UUID "
|
||||
"(e.g. 'c069dc6b'). Works with blocks_exclude. "
|
||||
"Leave empty to apply no block filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'blocks' list is interpreted. "
|
||||
"True (default): 'blocks' is a deny-list — listed blocks are blocked, "
|
||||
"all others are allowed. An empty 'blocks' list means allow everything. "
|
||||
"False: 'blocks' is an allow-list — only listed blocks are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# timeout_seconds removed: the SDK manages its own heartbeat-based
|
||||
# timeouts internally; wrapping with asyncio.timeout corrupts the
|
||||
# SDK's internal stream (see service.py CRITICAL comment).
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for the AutoPilot block."""
|
||||
|
||||
response: str = SchemaField(
|
||||
description="The final text response from the autopilot."
|
||||
)
|
||||
tool_calls: list[ToolCallEntry] = SchemaField(
|
||||
description=(
|
||||
"List of tools called during execution. Each entry has "
|
||||
"tool_call_id, tool_name, input, output, and success fields."
|
||||
),
|
||||
)
|
||||
conversation_history: str = SchemaField(
|
||||
description=(
|
||||
"Current turn messages (user prompt + assistant reply) as JSON. "
|
||||
"It can be used for logging or analysis."
|
||||
),
|
||||
)
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID for this conversation. "
|
||||
"Pass this back to continue the conversation in a future run."
|
||||
),
|
||||
)
|
||||
token_usage: TokenUsage = SchemaField(
|
||||
description=(
|
||||
"Token usage statistics: prompt_tokens, "
|
||||
"completion_tokens, total_tokens."
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id=AUTOPILOT_BLOCK_ID,
|
||||
description=(
|
||||
"Execute tasks using AutoGPT AutoPilot with full access to "
|
||||
"platform tools (agent management, workspace files, web fetch, "
|
||||
"block execution, and more). Enables sub-agent patterns and "
|
||||
"scheduled autopilot execution."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.AGENT},
|
||||
input_schema=AutoPilotBlock.Input,
|
||||
output_schema=AutoPilotBlock.Output,
|
||||
test_input={
|
||||
"prompt": "List my agents",
|
||||
"system_context": "",
|
||||
"session_id": "",
|
||||
"max_recursion_depth": 3,
|
||||
},
|
||||
test_output=[
|
||||
("response", "You have 2 agents: Agent A and Agent B."),
|
||||
("tool_calls", []),
|
||||
(
|
||||
"conversation_history",
|
||||
'[{"role": "user", "content": "List my agents"}]',
|
||||
),
|
||||
("session_id", "test-session-id"),
|
||||
(
|
||||
"token_usage",
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"create_session": lambda *args, **kwargs: "test-session-id",
|
||||
"execute_copilot": lambda *args, **kwargs: (
|
||||
"You have 2 agents: Agent A and Agent B.",
|
||||
[],
|
||||
'[{"role": "user", "content": "List my agents"}]',
|
||||
"test-session-id",
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def create_session(self, user_id: str) -> str:
|
||||
"""Create a new chat session and return its ID (mockable for tests)."""
|
||||
from backend.copilot.model import create_chat_session # avoid circular import
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
return session.session_id
|
||||
|
||||
async def execute_copilot(
|
||||
self,
|
||||
prompt: str,
|
||||
system_context: str,
|
||||
session_id: str,
|
||||
max_recursion_depth: int,
|
||||
user_id: str,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
|
||||
"""Invoke the copilot and collect all stream results.
|
||||
|
||||
Delegates to :func:`collect_copilot_response` — the shared helper that
|
||||
consumes ``stream_chat_completion_sdk`` without wrapping it in an
|
||||
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
|
||||
|
||||
Args:
|
||||
prompt: The user task/instruction.
|
||||
system_context: Optional context prepended to the prompt.
|
||||
session_id: Chat session to use.
|
||||
max_recursion_depth: Maximum allowed recursion nesting.
|
||||
user_id: Authenticated user ID.
|
||||
permissions: Optional capability filter restricting tools/blocks.
|
||||
|
||||
Returns:
|
||||
A tuple of (response_text, tool_calls, history_json, session_id, usage).
|
||||
"""
|
||||
from backend.copilot.sdk.collect import (
|
||||
collect_copilot_response, # avoid circular import
|
||||
)
|
||||
|
||||
tokens = _check_recursion(max_recursion_depth)
|
||||
perm_token = None
|
||||
try:
|
||||
effective_permissions, perm_token = _merge_inherited_permissions(
|
||||
permissions
|
||||
)
|
||||
effective_prompt = prompt
|
||||
if system_context:
|
||||
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
|
||||
|
||||
result = await collect_copilot_response(
|
||||
session_id=session_id,
|
||||
message=effective_prompt,
|
||||
user_id=user_id,
|
||||
permissions=effective_permissions,
|
||||
)
|
||||
|
||||
# Build a lightweight conversation summary from streamed data.
|
||||
turn_messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": effective_prompt},
|
||||
]
|
||||
if result.tool_calls:
|
||||
turn_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": result.response_text,
|
||||
"tool_calls": result.tool_calls,
|
||||
}
|
||||
)
|
||||
else:
|
||||
turn_messages.append(
|
||||
{"role": "assistant", "content": result.response_text}
|
||||
)
|
||||
history_json = json.dumps(turn_messages, default=str)
|
||||
|
||||
tool_calls: list[ToolCallEntry] = [
|
||||
{
|
||||
"tool_call_id": tc["tool_call_id"],
|
||||
"tool_name": tc["tool_name"],
|
||||
"input": tc["input"],
|
||||
"output": tc["output"],
|
||||
"success": tc["success"],
|
||||
}
|
||||
for tc in result.tool_calls
|
||||
]
|
||||
|
||||
usage: TokenUsage = {
|
||||
"prompt_tokens": result.prompt_tokens,
|
||||
"completion_tokens": result.completion_tokens,
|
||||
"total_tokens": result.total_tokens,
|
||||
}
|
||||
|
||||
return (
|
||||
result.response_text,
|
||||
tool_calls,
|
||||
history_json,
|
||||
session_id,
|
||||
usage,
|
||||
)
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
if perm_token is not None:
|
||||
_inherited_permissions.reset(perm_token)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Validate inputs, invoke the autopilot, and yield structured outputs.
|
||||
|
||||
Yields session_id even on failure so callers can inspect/resume the session.
|
||||
"""
|
||||
if not input_data.prompt.strip():
|
||||
yield "error", "Prompt cannot be empty."
|
||||
return
|
||||
|
||||
if not execution_context.user_id:
|
||||
yield "error", "Cannot run autopilot without an authenticated user."
|
||||
return
|
||||
|
||||
if input_data.max_recursion_depth < 1:
|
||||
yield "error", "max_recursion_depth must be at least 1."
|
||||
return
|
||||
|
||||
# Validate and build permissions eagerly — fail before creating a session.
|
||||
permissions = await _build_and_validate_permissions(input_data)
|
||||
if isinstance(permissions, str):
|
||||
# Validation error returned as a string message.
|
||||
yield "error", permissions
|
||||
return
|
||||
|
||||
# Create session eagerly so the user always gets the session_id,
|
||||
# even if the downstream stream fails (avoids orphaned sessions).
|
||||
sid = input_data.session_id
|
||||
if not sid:
|
||||
sid = await self.create_session(execution_context.user_id)
|
||||
|
||||
# NOTE: No asyncio.timeout() here — the SDK manages its own
|
||||
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout
|
||||
# would cancel the task mid-flight, corrupting the SDK's internal
|
||||
# anyio memory stream (see service.py CRITICAL comment).
|
||||
try:
|
||||
response, tool_calls, history, _, usage = await self.execute_copilot(
|
||||
prompt=input_data.prompt,
|
||||
system_context=input_data.system_context,
|
||||
session_id=sid,
|
||||
max_recursion_depth=input_data.max_recursion_depth,
|
||||
user_id=execution_context.user_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
yield "tool_calls", tool_calls
|
||||
yield "conversation_history", history
|
||||
yield "session_id", sid
|
||||
yield "token_usage", usage
|
||||
except asyncio.CancelledError:
|
||||
yield "session_id", sid
|
||||
yield "error", "AutoPilot execution was cancelled."
|
||||
raise
|
||||
except Exception as exc:
|
||||
yield "session_id", sid
|
||||
yield "error", str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers – placed after the block class for top-down readability.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Task-scoped recursion depth counter & chain-wide limit.
|
||||
# contextvars are scoped to the current asyncio task, so concurrent
|
||||
# graph executions each get independent counters.
|
||||
_autopilot_recursion_depth: contextvars.ContextVar[int] = contextvars.ContextVar(
|
||||
"_autopilot_recursion_depth", default=0
|
||||
)
|
||||
_autopilot_recursion_limit: contextvars.ContextVar[int | None] = contextvars.ContextVar(
|
||||
"_autopilot_recursion_limit", default=None
|
||||
)
|
||||
|
||||
|
||||
def _check_recursion(
|
||||
max_depth: int,
|
||||
) -> tuple[contextvars.Token[int], contextvars.Token[int | None]]:
|
||||
"""Check and increment recursion depth.
|
||||
|
||||
Returns ContextVar tokens that must be passed to ``_reset_recursion``
|
||||
when the caller exits to restore the previous depth.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the current depth already meets or exceeds the limit.
|
||||
"""
|
||||
current = _autopilot_recursion_depth.get()
|
||||
inherited = _autopilot_recursion_limit.get()
|
||||
limit = max_depth if inherited is None else min(inherited, max_depth)
|
||||
if current >= limit:
|
||||
raise RuntimeError(
|
||||
f"AutoPilot recursion depth limit reached ({limit}). "
|
||||
"The autopilot has called itself too many times."
|
||||
)
|
||||
return (
|
||||
_autopilot_recursion_depth.set(current + 1),
|
||||
_autopilot_recursion_limit.set(limit),
|
||||
)
|
||||
|
||||
|
||||
def _reset_recursion(
|
||||
tokens: tuple[contextvars.Token[int], contextvars.Token[int | None]],
|
||||
) -> None:
|
||||
"""Restore recursion depth and limit to their previous values."""
|
||||
_autopilot_recursion_depth.reset(tokens[0])
|
||||
_autopilot_recursion_limit.reset(tokens[1])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Permission helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Inherited permissions from a parent AutoPilotBlock execution.
|
||||
# This acts as a ceiling: child executions can only be more restrictive.
|
||||
_inherited_permissions: contextvars.ContextVar["CopilotPermissions | None"] = (
|
||||
contextvars.ContextVar("_inherited_permissions", default=None)
|
||||
)
|
||||
|
||||
|
||||
async def _build_and_validate_permissions(
|
||||
input_data: "AutoPilotBlock.Input",
|
||||
) -> "CopilotPermissions | str":
|
||||
"""Build a :class:`CopilotPermissions` from block input and validate it.
|
||||
|
||||
Returns a :class:`CopilotPermissions` on success or a human-readable
|
||||
error string if validation fails.
|
||||
"""
|
||||
# Tool names are validated by Pydantic via the ToolName Literal type
|
||||
# at model construction time — no runtime check needed here.
|
||||
# Validate block identifiers against live block registry.
|
||||
if input_data.blocks:
|
||||
invalid_blocks = await validate_block_identifiers(input_data.blocks)
|
||||
if invalid_blocks:
|
||||
return (
|
||||
f"Unknown block identifier(s) in 'blocks': {invalid_blocks}. "
|
||||
"Use find_block to discover valid block names and IDs. "
|
||||
"You may also use the first 8 characters of a block UUID."
|
||||
)
|
||||
|
||||
return CopilotPermissions(
|
||||
tools=list(input_data.tools),
|
||||
tools_exclude=input_data.tools_exclude,
|
||||
blocks=input_data.blocks,
|
||||
blocks_exclude=input_data.blocks_exclude,
|
||||
)
|
||||
|
||||
|
||||
def _merge_inherited_permissions(
|
||||
permissions: "CopilotPermissions | None",
|
||||
) -> "tuple[CopilotPermissions | None, contextvars.Token[CopilotPermissions | None] | None]":
|
||||
"""Merge *permissions* with any inherited parent permissions.
|
||||
|
||||
The merged result is stored back into the contextvar so that any nested
|
||||
AutoPilotBlock invocation (sub-agent) inherits the merged ceiling.
|
||||
|
||||
Returns a tuple of (merged_permissions, reset_token). The caller MUST
|
||||
reset the contextvar via ``_inherited_permissions.reset(token)`` in a
|
||||
``finally`` block when ``reset_token`` is not None — this prevents
|
||||
permission leakage between sequential independent executions in the same
|
||||
asyncio task.
|
||||
"""
|
||||
parent = _inherited_permissions.get()
|
||||
|
||||
if permissions is None and parent is None:
|
||||
return None, None
|
||||
|
||||
all_tools = all_known_tool_names()
|
||||
|
||||
if permissions is None:
|
||||
permissions = CopilotPermissions() # allow-all; will be narrowed by parent
|
||||
|
||||
merged = (
|
||||
permissions.merged_with_parent(parent, all_tools)
|
||||
if parent is not None
|
||||
else permissions
|
||||
)
|
||||
|
||||
# Store merged permissions as the new inherited ceiling for nested calls.
|
||||
# Return the token so the caller can restore the previous value in finally.
|
||||
token = _inherited_permissions.set(merged)
|
||||
return merged, token
|
||||
@@ -0,0 +1,265 @@
|
||||
"""Tests for AutoPilotBlock permission fields and validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AutoPilotBlock,
|
||||
_build_and_validate_permissions,
|
||||
_inherited_permissions,
|
||||
_merge_inherited_permissions,
|
||||
)
|
||||
from backend.copilot.permissions import CopilotPermissions, all_known_tool_names
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input(**kwargs) -> AutoPilotBlock.Input:
|
||||
defaults = {
|
||||
"prompt": "Do something",
|
||||
"system_context": "",
|
||||
"session_id": "",
|
||||
"max_recursion_depth": 3,
|
||||
"tools": [],
|
||||
"tools_exclude": True,
|
||||
"blocks": [],
|
||||
"blocks_exclude": True,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return AutoPilotBlock.Input(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_and_validate_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestBuildAndValidatePermissions:
|
||||
async def test_empty_inputs_returns_empty_permissions(self):
|
||||
inp = _make_input()
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.is_empty()
|
||||
|
||||
async def test_valid_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["run_block", "web_fetch"], tools_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.tools == ["run_block", "web_fetch"]
|
||||
assert result.tools_exclude is True
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are now caught at Pydantic validation time
|
||||
(Literal type), before ``_build_and_validate_permissions`` is called."""
|
||||
with pytest.raises(ValidationError, match="not_a_real_tool"):
|
||||
_make_input(tools=["not_a_real_tool"])
|
||||
|
||||
async def test_valid_block_name_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.blocks == ["HTTP Request"]
|
||||
|
||||
async def test_valid_partial_uuid_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["c069dc6b"], blocks_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
|
||||
async def test_invalid_block_identifier_returns_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["totally_fake_block"])
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, str)
|
||||
assert "totally_fake_block" in result
|
||||
assert "Unknown block identifier" in result
|
||||
|
||||
async def test_sdk_builtin_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["Read", "Task", "WebSearch"], tools_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert not result.tools_exclude
|
||||
|
||||
async def test_empty_blocks_skips_validation(self):
|
||||
# Should not call validate_block_identifiers at all when blocks=[].
|
||||
with patch(
|
||||
"backend.copilot.permissions.validate_block_identifiers"
|
||||
) as mock_validate:
|
||||
inp = _make_input(blocks=[])
|
||||
await _build_and_validate_permissions(inp)
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _merge_inherited_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergeInheritedPermissions:
|
||||
def test_no_permissions_no_parent_returns_none(self):
|
||||
merged, token = _merge_inherited_permissions(None)
|
||||
assert merged is None
|
||||
assert token is None
|
||||
|
||||
def test_permissions_no_parent_returned_unchanged(self):
|
||||
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
merged, token = _merge_inherited_permissions(perms)
|
||||
try:
|
||||
assert merged is perms
|
||||
assert token is not None
|
||||
finally:
|
||||
if token is not None:
|
||||
_inherited_permissions.reset(token)
|
||||
|
||||
def test_child_narrows_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
# Set parent as inherited
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
all_t = all_known_tool_names()
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert "bash_exec" not in effective
|
||||
assert "web_fetch" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_none_permissions_with_parent_uses_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
merged, inner_token = _merge_inherited_permissions(None)
|
||||
try:
|
||||
assert merged is not None
|
||||
# Merged should have parent's restrictions
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_child_cannot_expand_parent_whitelist(self):
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
# Child tries to allow more tools
|
||||
child = CopilotPermissions(
|
||||
tools=["run_block", "bash_exec"], tools_exclude=False
|
||||
)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
assert "run_block" in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AutoPilotBlock.run — validation integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAutoPilotBlockRunPermissions:
|
||||
async def _collect_outputs(self, block, input_data, user_id="test-user"):
|
||||
"""Helper to collect all yields from block.run()."""
|
||||
ctx = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id="g1",
|
||||
graph_exec_id="ge1",
|
||||
node_exec_id="ne1",
|
||||
node_id="n1",
|
||||
)
|
||||
outputs = {}
|
||||
async for key, val in block.run(input_data, execution_context=ctx):
|
||||
outputs[key] = val
|
||||
return outputs
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are caught at Pydantic validation (Literal type)."""
|
||||
with pytest.raises(ValidationError, match="not_a_tool"):
|
||||
_make_input(tools=["not_a_tool"])
|
||||
|
||||
async def test_invalid_block_yields_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(blocks=["nonexistent_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "nonexistent_block" in outputs["error"]
|
||||
|
||||
async def test_empty_prompt_yields_error_before_permission_check(self):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(prompt=" ", tools=["run_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "Prompt cannot be empty" in outputs["error"]
|
||||
|
||||
async def test_valid_permissions_passed_to_execute(self):
|
||||
"""Permissions are forwarded to execute_copilot when valid."""
|
||||
block = AutoPilotBlock()
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_execute_copilot(self_inner, **kwargs):
|
||||
captured["permissions"] = kwargs.get("permissions")
|
||||
return (
|
||||
"ok",
|
||||
[],
|
||||
'[{"role":"user","content":"hi"}]',
|
||||
"test-sid",
|
||||
{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
AutoPilotBlock, "create_session", new=AsyncMock(return_value="test-sid")
|
||||
), patch.object(AutoPilotBlock, "execute_copilot", new=fake_execute_copilot):
|
||||
inp = _make_input(tools=["run_block"], tools_exclude=False)
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
|
||||
assert "error" not in outputs
|
||||
perms = captured.get("permissions")
|
||||
assert isinstance(perms, CopilotPermissions)
|
||||
assert perms.tools == ["run_block"]
|
||||
assert perms.tools_exclude is False
|
||||
@@ -472,7 +472,7 @@ class AddToListBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
entries_added = input_data.entries.copy()
|
||||
if input_data.entry:
|
||||
if input_data.entry is not None:
|
||||
entries_added.append(input_data.entry)
|
||||
|
||||
updated_list = input_data.list.copy()
|
||||
|
||||
@@ -21,6 +21,7 @@ from backend.data.model import (
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import resolve_and_check_blocked
|
||||
|
||||
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -99,6 +100,8 @@ class SendEmailBlock(Block):
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
ALLOWED_SMTP_PORTS = {25, 465, 587, 2525}
|
||||
|
||||
@staticmethod
|
||||
def send_email(
|
||||
config: SMTPConfig,
|
||||
@@ -129,6 +132,17 @@ class SendEmailBlock(Block):
|
||||
self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# --- SSRF Protection ---
|
||||
smtp_port = input_data.config.smtp_port
|
||||
if smtp_port not in self.ALLOWED_SMTP_PORTS:
|
||||
yield "error", (
|
||||
f"SMTP port {smtp_port} is not allowed. "
|
||||
f"Allowed ports: {sorted(self.ALLOWED_SMTP_PORTS)}"
|
||||
)
|
||||
return
|
||||
|
||||
await resolve_and_check_blocked(input_data.config.smtp_server)
|
||||
|
||||
status = self.send_email(
|
||||
config=input_data.config,
|
||||
to_email=input_data.to_email,
|
||||
@@ -180,7 +194,19 @@ class SendEmailBlock(Block):
|
||||
"was rejected by the server. "
|
||||
"Please verify your account is authorized to send emails."
|
||||
)
|
||||
except smtplib.SMTPConnectError:
|
||||
yield "error", (
|
||||
f"Cannot connect to SMTP server '{input_data.config.smtp_server}' "
|
||||
f"on port {input_data.config.smtp_port}."
|
||||
)
|
||||
except smtplib.SMTPServerDisconnected:
|
||||
yield "error", (
|
||||
f"SMTP server '{input_data.config.smtp_server}' "
|
||||
"disconnected unexpectedly."
|
||||
)
|
||||
except smtplib.SMTPDataError as e:
|
||||
yield "error", f"Email data rejected by server: {str(e)}"
|
||||
except ValueError as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -34,17 +34,29 @@ TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class FluxKontextModelName(str, Enum):
|
||||
PRO = "Flux Kontext Pro"
|
||||
MAX = "Flux Kontext Max"
|
||||
class ImageEditorModel(str, Enum):
|
||||
FLUX_KONTEXT_PRO = "Flux Kontext Pro"
|
||||
FLUX_KONTEXT_MAX = "Flux Kontext Max"
|
||||
NANO_BANANA_PRO = "Nano Banana Pro"
|
||||
NANO_BANANA_2 = "Nano Banana 2"
|
||||
|
||||
@property
|
||||
def api_name(self) -> str:
|
||||
return f"black-forest-labs/flux-kontext-{self.name.lower()}"
|
||||
_map = {
|
||||
"FLUX_KONTEXT_PRO": "black-forest-labs/flux-kontext-pro",
|
||||
"FLUX_KONTEXT_MAX": "black-forest-labs/flux-kontext-max",
|
||||
"NANO_BANANA_PRO": "google/nano-banana-pro",
|
||||
"NANO_BANANA_2": "google/nano-banana-2",
|
||||
}
|
||||
return _map[self.name]
|
||||
|
||||
|
||||
# Keep old name as alias for backwards compatibility
|
||||
FluxKontextModelName = ImageEditorModel
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -69,7 +81,7 @@ class AIImageEditorBlock(Block):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Replicate API key with permissions for Flux Kontext models",
|
||||
description="Replicate API key with permissions for Flux Kontext and Nano Banana models",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="Text instruction describing the desired edit",
|
||||
@@ -87,14 +99,14 @@ class AIImageEditorBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
seed: Optional[int] = SchemaField(
|
||||
description="Random seed. Set for reproducible generation",
|
||||
description="Random seed. Set for reproducible generation (Flux Kontext only; ignored by Nano Banana models)",
|
||||
default=None,
|
||||
title="Seed",
|
||||
advanced=True,
|
||||
)
|
||||
model: FluxKontextModelName = SchemaField(
|
||||
model: ImageEditorModel = SchemaField(
|
||||
description="Model variant to use",
|
||||
default=FluxKontextModelName.PRO,
|
||||
default=ImageEditorModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
|
||||
@@ -107,7 +119,7 @@ class AIImageEditorBlock(Block):
|
||||
super().__init__(
|
||||
id="3fd9c73d-4370-4925-a1ff-1b86b99fabfa",
|
||||
description=(
|
||||
"Edit images using BlackForest Labs' Flux Kontext models. Provide a prompt "
|
||||
"Edit images using Flux Kontext or Google Nano Banana models. Provide a prompt "
|
||||
"and optional reference image to generate a modified image."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
@@ -118,7 +130,7 @@ class AIImageEditorBlock(Block):
|
||||
"input_image": "data:image/png;base64,MQ==",
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"seed": None,
|
||||
"model": FluxKontextModelName.PRO,
|
||||
"model": ImageEditorModel.NANO_BANANA_2,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
@@ -127,7 +139,9 @@ class AIImageEditorBlock(Block):
|
||||
],
|
||||
test_mock={
|
||||
# Use data URI to avoid HTTP requests during tests
|
||||
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
|
||||
"run_model": lambda *args, **kwargs: (
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
),
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
@@ -142,7 +156,7 @@ class AIImageEditorBlock(Block):
|
||||
) -> BlockOutput:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.api_name,
|
||||
model=input_data.model,
|
||||
prompt=input_data.prompt,
|
||||
input_image_b64=(
|
||||
await store_media_file(
|
||||
@@ -169,7 +183,7 @@ class AIImageEditorBlock(Block):
|
||||
async def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
model: ImageEditorModel,
|
||||
prompt: str,
|
||||
input_image_b64: Optional[str],
|
||||
aspect_ratio: str,
|
||||
@@ -178,12 +192,29 @@ class AIImageEditorBlock(Block):
|
||||
graph_exec_id: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
model_name = model.api_name
|
||||
|
||||
is_nano_banana = model in (
|
||||
ImageEditorModel.NANO_BANANA_PRO,
|
||||
ImageEditorModel.NANO_BANANA_2,
|
||||
)
|
||||
if is_nano_banana:
|
||||
input_params: dict = {
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"output_format": "jpg",
|
||||
"safety_filter_level": "block_only_high",
|
||||
}
|
||||
# NB API expects "image_input" as a list, unlike Flux's single "input_image"
|
||||
if input_image_b64:
|
||||
input_params["image_input"] = [input_image_b64]
|
||||
else:
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
|
||||
try:
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
|
||||
|
||||
@@ -211,7 +211,7 @@ class AgentOutputBlock(Block):
|
||||
if input_data.format:
|
||||
try:
|
||||
formatter = TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", formatter.format_string(
|
||||
yield "output", await formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -33,6 +33,13 @@ from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.openai_responses import (
|
||||
convert_tools_to_responses_format,
|
||||
extract_responses_content,
|
||||
extract_responses_reasoning,
|
||||
extract_responses_tool_calls,
|
||||
extract_responses_usage,
|
||||
)
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.request import validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
@@ -42,6 +49,9 @@ settings = Settings()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
|
||||
# HTTP status codes for user-caused errors that should not be reported to Sentry.
|
||||
USER_ERROR_STATUS_CODES = (401, 403, 429)
|
||||
|
||||
LLMProviderName = Literal[
|
||||
ProviderName.AIML_API,
|
||||
ProviderName.ANTHROPIC,
|
||||
@@ -111,7 +121,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
GPT4O_MINI = "gpt-4o-mini"
|
||||
GPT4O = "gpt-4o"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
|
||||
CLAUDE_4_OPUS = "claude-opus-4-20250514"
|
||||
@@ -277,9 +286,6 @@ MODEL_METADATA = {
|
||||
LlmModel.GPT4_TURBO: ModelMetadata(
|
||||
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
|
||||
), # gpt-4-turbo-2024-04-09
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata(
|
||||
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
|
||||
), # gpt-3.5-turbo-0125
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
|
||||
@@ -793,6 +799,19 @@ async def llm_call(
|
||||
)
|
||||
prompt = result.messages
|
||||
|
||||
# Sanitize unpaired surrogates in message content to prevent
|
||||
# UnicodeEncodeError when httpx encodes the JSON request body.
|
||||
for msg in prompt:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content.encode("utf-8")
|
||||
except UnicodeEncodeError:
|
||||
logger.warning("Sanitized unpaired surrogates in LLM prompt content")
|
||||
msg["content"] = content.encode("utf-8", errors="surrogatepass").decode(
|
||||
"utf-8", errors="replace"
|
||||
)
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
@@ -801,36 +820,53 @@ async def llm_call(
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
|
||||
|
||||
if provider == "openai":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = None
|
||||
|
||||
parallel_tool_calls = get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
tools_param = convert_tools_to_responses_format(tools) if tools else openai.omit
|
||||
|
||||
text_config = openai.omit
|
||||
if force_json_output:
|
||||
response_format = {"type": "json_object"}
|
||||
text_config = {"format": {"type": "json_object"}} # type: ignore
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
response = await oai_client.responses.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
input=prompt, # type: ignore[arg-type]
|
||||
tools=tools_param, # type: ignore[arg-type]
|
||||
max_output_tokens=max_tokens,
|
||||
parallel_tool_calls=get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
),
|
||||
text=text_config, # type: ignore[arg-type]
|
||||
store=False,
|
||||
)
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
raw_tool_calls = extract_responses_tool_calls(response)
|
||||
tool_calls = (
|
||||
[
|
||||
ToolContentBlock(
|
||||
id=tc["id"],
|
||||
type=tc["type"],
|
||||
function=ToolCall(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
)
|
||||
for tc in raw_tool_calls
|
||||
]
|
||||
if raw_tool_calls
|
||||
else None
|
||||
)
|
||||
reasoning = extract_responses_reasoning(response)
|
||||
content = extract_responses_content(response)
|
||||
prompt_tokens, completion_tokens = extract_responses_usage(response)
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
raw_response=response,
|
||||
prompt=prompt,
|
||||
response=response.choices[0].message.content or "",
|
||||
response=content,
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
@@ -858,65 +894,60 @@ async def llm_call(
|
||||
client = anthropic.AsyncAnthropic(
|
||||
api_key=credentials.api_key.get_secret_value()
|
||||
)
|
||||
try:
|
||||
resp = await client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
resp = await client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
# Antropic is different to openai, need to iterate through
|
||||
# the content blocks to find the tool calls
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
# Antropic is different to openai, need to iterate through
|
||||
# the content blocks to find the tool calls
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else getattr(resp.content[0], "text", "")
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else getattr(resp.content[0], "text", "")
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "groq":
|
||||
if tools:
|
||||
raise ValueError("Groq does not support tools.")
|
||||
@@ -1276,8 +1307,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = fmt.format_string(input_data.sys_prompt, values)
|
||||
input_data.prompt = await fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = await fmt.format_string(
|
||||
input_data.sys_prompt, values
|
||||
)
|
||||
|
||||
if input_data.sys_prompt:
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
@@ -1427,7 +1460,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
is_user_error = (
|
||||
isinstance(e, (anthropic.APIStatusError, openai.APIStatusError))
|
||||
and e.status_code in USER_ERROR_STATUS_CODES
|
||||
)
|
||||
if is_user_error:
|
||||
logger.warning(f"Error calling LLM: {e}")
|
||||
error_feedback_message = f"Error calling LLM: {e}"
|
||||
break
|
||||
else:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
if (
|
||||
"maximum context length" in str(e).lower()
|
||||
or "token limit" in str(e).lower()
|
||||
|
||||
@@ -61,20 +61,27 @@ class ExecutionParams(BaseModel):
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool request.
|
||||
Supports both OpenAI and Anthropics formats.
|
||||
Supports OpenAI Chat Completions, Responses API, and Anthropic formats.
|
||||
"""
|
||||
tool_call_ids = []
|
||||
|
||||
# OpenAI Responses API: function_call items have type="function_call"
|
||||
if entry.get("type") == "function_call":
|
||||
if call_id := entry.get("call_id"):
|
||||
tool_call_ids.append(call_id)
|
||||
return tool_call_ids
|
||||
|
||||
if entry.get("role") != "assistant":
|
||||
return tool_call_ids
|
||||
|
||||
# OpenAI: check for tool_calls in the entry.
|
||||
# OpenAI Chat Completions: check for tool_calls in the entry.
|
||||
calls = entry.get("tool_calls")
|
||||
if isinstance(calls, list):
|
||||
for call in calls:
|
||||
if tool_id := call.get("id"):
|
||||
tool_call_ids.append(tool_id)
|
||||
|
||||
# Anthropics: check content items for tool_use type.
|
||||
# Anthropic: check content items for tool_use type.
|
||||
content = entry.get("content")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
@@ -89,16 +96,22 @@ def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool response.
|
||||
Supports both OpenAI and Anthropics formats.
|
||||
Supports OpenAI Chat Completions, Responses API, and Anthropic formats.
|
||||
"""
|
||||
tool_call_ids: list[str] = []
|
||||
|
||||
# OpenAI: a tool response message with role "tool" and key "tool_call_id".
|
||||
# OpenAI Responses API: function_call_output items
|
||||
if entry.get("type") == "function_call_output":
|
||||
if call_id := entry.get("call_id"):
|
||||
tool_call_ids.append(str(call_id))
|
||||
return tool_call_ids
|
||||
|
||||
# OpenAI Chat Completions: a tool response message with role "tool".
|
||||
if entry.get("role") == "tool":
|
||||
if tool_call_id := entry.get("tool_call_id"):
|
||||
tool_call_ids.append(str(tool_call_id))
|
||||
|
||||
# Anthropics: check content items for tool_result type.
|
||||
# Anthropic: check content items for tool_result type.
|
||||
if entry.get("role") == "user":
|
||||
content = entry.get("content")
|
||||
if isinstance(content, list):
|
||||
@@ -111,14 +124,16 @@ def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
|
||||
return tool_call_ids
|
||||
|
||||
|
||||
def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
|
||||
def _create_tool_response(
|
||||
call_id: str, output: Any, *, responses_api: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a tool response message for either OpenAI or Anthropics,
|
||||
based on the tool_id format.
|
||||
Create a tool response message for OpenAI, Anthropic, or OpenAI Responses API,
|
||||
based on the tool_id format and the responses_api flag.
|
||||
"""
|
||||
content = output if isinstance(output, str) else json.dumps(output)
|
||||
|
||||
# Anthropics format: tool IDs typically start with "toolu_"
|
||||
# Anthropic format: tool IDs typically start with "toolu_"
|
||||
if call_id.startswith("toolu_"):
|
||||
return {
|
||||
"role": "user",
|
||||
@@ -128,8 +143,11 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
|
||||
],
|
||||
}
|
||||
|
||||
# OpenAI format: tool IDs typically start with "call_".
|
||||
# Or default fallback (if the tool_id doesn't match any known prefix)
|
||||
# OpenAI Responses API format
|
||||
if responses_api:
|
||||
return {"type": "function_call_output", "call_id": call_id, "output": content}
|
||||
|
||||
# OpenAI Chat Completions format (default fallback)
|
||||
return {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
|
||||
|
||||
@@ -177,10 +195,19 @@ def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str
|
||||
return tool_outputs
|
||||
|
||||
|
||||
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
|
||||
def _convert_raw_response_to_dict(
|
||||
raw_response: Any,
|
||||
) -> dict[str, Any] | list[dict[str, Any]]:
|
||||
"""
|
||||
Safely convert raw_response to dictionary format for conversation history.
|
||||
Handles different response types from different LLM providers.
|
||||
|
||||
For the OpenAI Responses API, the raw_response is the entire Response
|
||||
object. Its ``output`` items (messages, function_calls) are extracted
|
||||
individually so they can be used as valid input items on the next call.
|
||||
Returns a **list** of dicts in that case.
|
||||
|
||||
For Chat Completions / Anthropic / Ollama, returns a single dict.
|
||||
"""
|
||||
if isinstance(raw_response, str):
|
||||
# Ollama returns a string, convert to dict format
|
||||
@@ -188,11 +215,28 @@ def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
|
||||
elif isinstance(raw_response, dict):
|
||||
# Already a dict (from tests or some providers)
|
||||
return raw_response
|
||||
elif _is_responses_api_object(raw_response):
|
||||
# OpenAI Responses API: extract individual output items
|
||||
items = [json.to_dict(item) for item in raw_response.output]
|
||||
return items if items else [{"role": "assistant", "content": ""}]
|
||||
else:
|
||||
# OpenAI/Anthropic return objects, convert with json.to_dict
|
||||
# Chat Completions / Anthropic return message objects
|
||||
return json.to_dict(raw_response)
|
||||
|
||||
|
||||
def _is_responses_api_object(obj: Any) -> bool:
|
||||
"""Detect an OpenAI Responses API Response object.
|
||||
|
||||
These have ``object == "response"`` and an ``output`` list, but no
|
||||
``role`` attribute (unlike ChatCompletionMessage).
|
||||
"""
|
||||
return (
|
||||
getattr(obj, "object", None) == "response"
|
||||
and hasattr(obj, "output")
|
||||
and not hasattr(obj, "role")
|
||||
)
|
||||
|
||||
|
||||
def get_pending_tool_calls(conversation_history: list[Any] | None) -> dict[str, int]:
|
||||
"""
|
||||
All the tool calls entry in the conversation history requires a response.
|
||||
@@ -214,9 +258,10 @@ def get_pending_tool_calls(conversation_history: list[Any] | None) -> dict[str,
|
||||
return {call_id: count for call_id, count in pending_calls.items() if count > 0}
|
||||
|
||||
|
||||
class SmartDecisionMakerBlock(Block):
|
||||
class OrchestratorBlock(Block):
|
||||
"""
|
||||
A block that uses a language model to make smart decisions based on a given prompt.
|
||||
A block that uses a language model to orchestrate tool calls, supporting both
|
||||
single-shot and iterative agent mode execution.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -357,8 +402,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
description="Uses AI to intelligently decide what tool to use.",
|
||||
categories={BlockCategory.AI},
|
||||
block_type=BlockType.AI,
|
||||
input_schema=SmartDecisionMakerBlock.Input,
|
||||
output_schema=SmartDecisionMakerBlock.Output,
|
||||
input_schema=OrchestratorBlock.Input,
|
||||
output_schema=OrchestratorBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
@@ -396,7 +441,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_name = custom_name if custom_name else block.name
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
||||
"name": OrchestratorBlock.cleanup(tool_name),
|
||||
"description": block.description,
|
||||
}
|
||||
sink_block_input_schema = block.input_schema
|
||||
@@ -407,7 +452,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
field_name = link.sink_name
|
||||
is_dynamic = is_dynamic_field(field_name)
|
||||
# Clean property key to ensure Anthropic API compatibility for ALL fields
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
clean_field_name = OrchestratorBlock.cleanup(field_name)
|
||||
field_mapping[clean_field_name] = field_name
|
||||
|
||||
if is_dynamic:
|
||||
@@ -441,7 +486,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
field_name = link.sink_name
|
||||
is_dynamic = is_dynamic_field(field_name)
|
||||
# Always use cleaned field name for property key (Anthropic API compliance)
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
clean_field_name = OrchestratorBlock.cleanup(field_name)
|
||||
|
||||
if is_dynamic:
|
||||
base_name = extract_base_field_name(field_name)
|
||||
@@ -498,7 +543,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_name = custom_name if custom_name else sink_graph_meta.name
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
||||
"name": OrchestratorBlock.cleanup(tool_name),
|
||||
"description": sink_graph_meta.description,
|
||||
}
|
||||
|
||||
@@ -508,7 +553,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
for link in links:
|
||||
field_name = link.sink_name
|
||||
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
clean_field_name = OrchestratorBlock.cleanup(field_name)
|
||||
field_mapping[clean_field_name] = field_name
|
||||
|
||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||
@@ -574,17 +619,13 @@ class SmartDecisionMakerBlock(Block):
|
||||
raise ValueError(f"Sink node not found: {links[0].sink_id}")
|
||||
|
||||
if sink_node.block_id == AgentExecutorBlock().id:
|
||||
tool_func = (
|
||||
await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
tool_func = await OrchestratorBlock._create_agent_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
return_tool_functions.append(tool_func)
|
||||
else:
|
||||
tool_func = (
|
||||
await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
tool_func = await OrchestratorBlock._create_block_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
return_tool_functions.append(tool_func)
|
||||
|
||||
@@ -754,19 +795,34 @@ class SmartDecisionMakerBlock(Block):
|
||||
self, prompt: list[dict], response, tool_outputs: list | None = None
|
||||
):
|
||||
"""Update conversation history with response and tool outputs."""
|
||||
# Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
|
||||
assistant_message = _convert_raw_response_to_dict(response.raw_response)
|
||||
has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
|
||||
item.get("type") == "tool_use"
|
||||
for item in assistant_message.get("content", [])
|
||||
)
|
||||
converted = _convert_raw_response_to_dict(response.raw_response)
|
||||
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
if isinstance(converted, list):
|
||||
# Responses API: output items are already individual dicts
|
||||
has_tool_calls = any(
|
||||
item.get("type") == "function_call" for item in converted
|
||||
)
|
||||
|
||||
prompt.append(assistant_message)
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"[Reasoning]: {response.reasoning}",
|
||||
}
|
||||
)
|
||||
prompt.extend(converted)
|
||||
else:
|
||||
# Chat Completions / Anthropic: single assistant message dict
|
||||
has_tool_calls = isinstance(converted.get("content"), list) and any(
|
||||
item.get("type") == "tool_use" for item in converted.get("content", [])
|
||||
)
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"[Reasoning]: {response.reasoning}",
|
||||
}
|
||||
)
|
||||
prompt.append(converted)
|
||||
|
||||
if tool_outputs:
|
||||
prompt.extend(tool_outputs)
|
||||
@@ -776,6 +832,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_info: ToolInfo,
|
||||
execution_params: ExecutionParams,
|
||||
execution_processor: "ExecutionProcessor",
|
||||
*,
|
||||
responses_api: bool = False,
|
||||
) -> dict:
|
||||
"""Execute a single tool using the execution manager for proper integration."""
|
||||
# Lazy imports to avoid circular dependencies
|
||||
@@ -847,7 +905,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
task=node_exec_future,
|
||||
)
|
||||
|
||||
# Execute the node directly since we're in the SmartDecisionMaker context
|
||||
# Execute the node directly since we're in the Orchestrator context
|
||||
node_exec_future.set_result(
|
||||
await execution_processor.on_node_execution(
|
||||
node_exec=node_exec_entry,
|
||||
@@ -868,13 +926,17 @@ class SmartDecisionMakerBlock(Block):
|
||||
if node_outputs
|
||||
else "Tool executed successfully"
|
||||
)
|
||||
return _create_tool_response(tool_call.id, tool_response_content)
|
||||
return _create_tool_response(
|
||||
tool_call.id, tool_response_content, responses_api=responses_api
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution with manager failed: {e}")
|
||||
logger.warning(f"Tool execution with manager failed: {e}")
|
||||
# Return error response
|
||||
return _create_tool_response(
|
||||
tool_call.id, f"Tool execution failed: {str(e)}"
|
||||
tool_call.id,
|
||||
f"Tool execution failed: {str(e)}",
|
||||
responses_api=responses_api,
|
||||
)
|
||||
|
||||
async def _execute_tools_agent_mode(
|
||||
@@ -895,6 +957,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
"""Execute tools in agent mode with a loop until finished."""
|
||||
max_iterations = input_data.agent_mode_max_iterations
|
||||
iteration = 0
|
||||
use_responses_api = input_data.model.metadata.provider == "openai"
|
||||
|
||||
# Execution parameters for tool execution
|
||||
execution_params = ExecutionParams(
|
||||
@@ -951,14 +1014,19 @@ class SmartDecisionMakerBlock(Block):
|
||||
for tool_info in processed_tools:
|
||||
try:
|
||||
tool_response = await self._execute_single_tool_with_manager(
|
||||
tool_info, execution_params, execution_processor
|
||||
tool_info,
|
||||
execution_params,
|
||||
execution_processor,
|
||||
responses_api=use_responses_api,
|
||||
)
|
||||
tool_outputs.append(tool_response)
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution failed: {e}")
|
||||
# Create error response for the tool
|
||||
error_response = _create_tool_response(
|
||||
tool_info.tool_call.id, f"Error: {str(e)}"
|
||||
tool_info.tool_call.id,
|
||||
f"Error: {str(e)}",
|
||||
responses_api=use_responses_api,
|
||||
)
|
||||
tool_outputs.append(error_response)
|
||||
|
||||
@@ -1020,11 +1088,17 @@ class SmartDecisionMakerBlock(Block):
|
||||
if pending_tool_calls and input_data.last_tool_output is None:
|
||||
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
|
||||
|
||||
use_responses_api = input_data.model.metadata.provider == "openai"
|
||||
|
||||
tool_output = []
|
||||
if pending_tool_calls and input_data.last_tool_output is not None:
|
||||
first_call_id = next(iter(pending_tool_calls.keys()))
|
||||
tool_output.append(
|
||||
_create_tool_response(first_call_id, input_data.last_tool_output)
|
||||
_create_tool_response(
|
||||
first_call_id,
|
||||
input_data.last_tool_output,
|
||||
responses_api=use_responses_api,
|
||||
)
|
||||
)
|
||||
|
||||
prompt.extend(tool_output)
|
||||
@@ -1035,7 +1109,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
return
|
||||
elif input_data.last_tool_output:
|
||||
logger.error(
|
||||
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
|
||||
f"[OrchestratorBlock-node_exec_id={node_exec_id}] "
|
||||
f"No pending tool calls found. This may indicate an issue with the "
|
||||
f"conversation history, or the tool giving response more than once."
|
||||
f"This should not happen! Please check the conversation history for any inconsistencies."
|
||||
@@ -1050,11 +1124,15 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
|
||||
input_data.prompt = await llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = await llm.fmt.format_string(
|
||||
input_data.sys_prompt, values
|
||||
)
|
||||
|
||||
if input_data.sys_prompt and not any(
|
||||
p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
p.get("role") == "system"
|
||||
and isinstance(p.get("content"), str)
|
||||
and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
for p in prompt
|
||||
):
|
||||
prompt.append(
|
||||
@@ -1065,7 +1143,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
|
||||
if input_data.prompt and not any(
|
||||
p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
p.get("role") == "user"
|
||||
and isinstance(p.get("content"), str)
|
||||
and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
for p in prompt
|
||||
):
|
||||
prompt.append(
|
||||
@@ -1166,18 +1246,33 @@ class SmartDecisionMakerBlock(Block):
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}"
|
||||
|
||||
logger.debug(
|
||||
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
||||
"[OrchestratorBlock|geid:%s|neid:%s] emit %s",
|
||||
graph_exec_id,
|
||||
node_exec_id,
|
||||
emit_key,
|
||||
)
|
||||
yield emit_key, arg_value
|
||||
|
||||
if response.reasoning:
|
||||
converted = _convert_raw_response_to_dict(response.raw_response)
|
||||
|
||||
# Check for tool calls to avoid inserting reasoning between tool pairs
|
||||
if isinstance(converted, list):
|
||||
has_tool_calls = any(
|
||||
item.get("type") == "function_call" for item in converted
|
||||
)
|
||||
else:
|
||||
has_tool_calls = isinstance(converted.get("content"), list) and any(
|
||||
item.get("type") == "tool_use" for item in converted.get("content", [])
|
||||
)
|
||||
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
)
|
||||
|
||||
prompt.append(_convert_raw_response_to_dict(response.raw_response))
|
||||
if isinstance(converted, list):
|
||||
prompt.extend(converted)
|
||||
else:
|
||||
prompt.append(converted)
|
||||
|
||||
yield "conversations", prompt
|
||||
@@ -1,13 +1,8 @@
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
# Monkey patch Stagehands to prevent signal handling in worker threads
|
||||
import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
from stagehand import AsyncStagehand
|
||||
from stagehand.types.session_act_params import Options as ActOptions
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
@@ -28,46 +23,6 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
|
||||
warnings.filterwarnings("ignore", module="litellm.llms.custom_httpx")
|
||||
|
||||
# Store the original method
|
||||
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
|
||||
|
||||
|
||||
def safe_register_signal_handlers(self):
|
||||
"""Only register signal handlers in the main thread"""
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
original_register_signal_handlers(self)
|
||||
else:
|
||||
# Skip signal handling in worker threads
|
||||
pass
|
||||
|
||||
|
||||
# Replace the method
|
||||
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_signal_handling():
|
||||
"""Context manager to temporarily disable signal handling"""
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
# In worker threads, temporarily replace signal.signal with a no-op
|
||||
original_signal = signal.signal
|
||||
|
||||
def noop_signal(*args, **kwargs):
|
||||
pass
|
||||
|
||||
signal.signal = noop_signal
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal = original_signal
|
||||
else:
|
||||
# In main thread, don't modify anything
|
||||
yield
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -148,13 +103,10 @@ class StagehandObserveBlock(Block):
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
@@ -185,32 +137,28 @@ class StagehandObserveBlock(Block):
|
||||
|
||||
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
|
||||
observe_results = await page.observe(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
for result in observe_results:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
observe_response = await session.observe(
|
||||
instruction=input_data.instruction,
|
||||
)
|
||||
for result in observe_response.data.result:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
|
||||
class StagehandActBlock(Block):
|
||||
@@ -242,24 +190,22 @@ class StagehandActBlock(Block):
|
||||
description="Variables to use in the action. Variables contains data you want the action to use.",
|
||||
default_factory=dict,
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
timeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
|
||||
default=60000,
|
||||
timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms for each action.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the action was completed successfully"
|
||||
)
|
||||
message: str = SchemaField(description="Details about the action’s execution.")
|
||||
message: str = SchemaField(description="Details about the action's execution.")
|
||||
action: str = SchemaField(description="Action performed")
|
||||
|
||||
def __init__(self):
|
||||
@@ -282,32 +228,33 @@ class StagehandActBlock(Block):
|
||||
|
||||
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
for action in input_data.action:
|
||||
action_results = await page.act(
|
||||
action,
|
||||
variables=input_data.variables,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
timeoutMs=input_data.timeoutMs,
|
||||
)
|
||||
yield "success", action_results.success
|
||||
yield "message", action_results.message
|
||||
yield "action", action_results.action
|
||||
for action in input_data.action:
|
||||
act_options = ActOptions(
|
||||
variables={k: v for k, v in input_data.variables.items()},
|
||||
timeout=input_data.timeout_ms,
|
||||
)
|
||||
act_response = await session.act(
|
||||
input=action,
|
||||
options=act_options,
|
||||
)
|
||||
result = act_response.data.result
|
||||
yield "success", result.success
|
||||
yield "message", result.message
|
||||
yield "action", result.action_description
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
|
||||
class StagehandExtractBlock(Block):
|
||||
@@ -335,13 +282,10 @@ class StagehandExtractBlock(Block):
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
@@ -367,24 +311,21 @@ class StagehandExtractBlock(Block):
|
||||
|
||||
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
extraction = await page.extract(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
yield "extraction", str(extraction.model_dump()["extraction"])
|
||||
extract_response = await session.extract(
|
||||
instruction=input_data.instruction,
|
||||
)
|
||||
yield "extraction", str(extract_response.data.result)
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
223
autogpt_platform/backend/backend/blocks/test/test_autopilot.py
Normal file
223
autogpt_platform/backend/backend/blocks/test/test_autopilot.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Tests for AutoPilotBlock: recursion guard, streaming, validation, and error paths."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AUTOPILOT_BLOCK_ID,
|
||||
AutoPilotBlock,
|
||||
_autopilot_recursion_depth,
|
||||
_autopilot_recursion_limit,
|
||||
_check_recursion,
|
||||
_reset_recursion,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
|
||||
def _make_context(user_id: str = "test-user-123") -> ExecutionContext:
|
||||
"""Helper to build an ExecutionContext for tests."""
|
||||
return ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id="graph-1",
|
||||
graph_exec_id="gexec-1",
|
||||
graph_version=1,
|
||||
node_id="node-1",
|
||||
node_exec_id="nexec-1",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recursion guard unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRecursion:
|
||||
"""Unit tests for _check_recursion / _reset_recursion."""
|
||||
|
||||
def test_first_call_increments_depth(self):
|
||||
tokens = _check_recursion(3)
|
||||
try:
|
||||
assert _autopilot_recursion_depth.get() == 1
|
||||
assert _autopilot_recursion_limit.get() == 3
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
|
||||
def test_reset_restores_previous_values(self):
|
||||
assert _autopilot_recursion_depth.get() == 0
|
||||
assert _autopilot_recursion_limit.get() is None
|
||||
tokens = _check_recursion(5)
|
||||
_reset_recursion(tokens)
|
||||
assert _autopilot_recursion_depth.get() == 0
|
||||
assert _autopilot_recursion_limit.get() is None
|
||||
|
||||
def test_exceeding_limit_raises(self):
|
||||
t1 = _check_recursion(2)
|
||||
try:
|
||||
t2 = _check_recursion(2)
|
||||
try:
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
_check_recursion(2)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
def test_nested_calls_respect_inherited_limit(self):
|
||||
"""Inner call with higher max_depth still respects outer limit."""
|
||||
t1 = _check_recursion(2) # sets limit=2
|
||||
try:
|
||||
t2 = _check_recursion(10) # inner wants 10, but inherited is 2
|
||||
try:
|
||||
# depth is now 2, limit is min(10, 2) = 2 → should raise
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
_check_recursion(10)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
def test_limit_of_one_blocks_immediately_on_second_call(self):
|
||||
t1 = _check_recursion(1)
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
_check_recursion(1)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AutoPilotBlock.run() validation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunValidation:
|
||||
"""Tests for input validation in AutoPilotBlock.run()."""
|
||||
|
||||
@pytest.fixture
|
||||
def block(self):
|
||||
return AutoPilotBlock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_prompt_yields_error(self, block):
|
||||
block.Input # ensure schema is accessible
|
||||
input_data = block.Input(prompt=" ", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
assert outputs.get("error") == "Prompt cannot be empty."
|
||||
assert "response" not in outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_user_id_yields_error(self, block):
|
||||
input_data = block.Input(prompt="hello", max_recursion_depth=3)
|
||||
ctx = _make_context(user_id="")
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
assert "authenticated user" in outputs.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_run_yields_all_outputs(self, block):
|
||||
"""With execute_copilot mocked, run() should yield all 5 success outputs."""
|
||||
mock_result = (
|
||||
"Hello world",
|
||||
[],
|
||||
'[{"role":"user","content":"hi"}]',
|
||||
"sess-abc",
|
||||
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
block.execute_copilot = AsyncMock(return_value=mock_result)
|
||||
block.create_session = AsyncMock(return_value="sess-abc")
|
||||
|
||||
input_data = block.Input(prompt="hi", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["response"] == "Hello world"
|
||||
assert outputs["tool_calls"] == []
|
||||
assert outputs["session_id"] == "sess-abc"
|
||||
assert outputs["token_usage"]["total_tokens"] == 15
|
||||
assert "error" not in outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_yields_error(self, block):
|
||||
"""On unexpected failure, run() should yield an error output."""
|
||||
block.execute_copilot = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
block.create_session = AsyncMock(return_value="sess-fail")
|
||||
|
||||
input_data = block.Input(prompt="do something", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["session_id"] == "sess-fail"
|
||||
assert "boom" in outputs.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancelled_error_yields_error_and_reraises(self, block):
|
||||
"""CancelledError should yield error, then re-raise."""
|
||||
block.execute_copilot = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
block.create_session = AsyncMock(return_value="sess-cancel")
|
||||
|
||||
input_data = block.Input(prompt="do something", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["session_id"] == "sess-cancel"
|
||||
assert "cancelled" in outputs.get("error", "").lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_existing_session_id_skips_create(self, block):
|
||||
"""When session_id is provided, create_session should not be called."""
|
||||
mock_result = (
|
||||
"ok",
|
||||
[],
|
||||
"[]",
|
||||
"existing-sid",
|
||||
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
block.execute_copilot = AsyncMock(return_value=mock_result)
|
||||
block.create_session = AsyncMock()
|
||||
|
||||
input_data = block.Input(
|
||||
prompt="test", session_id="existing-sid", max_recursion_depth=3
|
||||
)
|
||||
ctx = _make_context()
|
||||
async for _ in block.run(input_data, execution_context=ctx):
|
||||
pass
|
||||
|
||||
block.create_session.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Block registration / ID tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBlockRegistration:
|
||||
def test_block_id_matches_constant(self):
|
||||
block = AutoPilotBlock()
|
||||
assert block.id == AUTOPILOT_BLOCK_ID
|
||||
|
||||
def test_max_recursion_depth_has_upper_bound(self):
|
||||
"""Schema should enforce le=10."""
|
||||
schema = AutoPilotBlock.Input.model_json_schema()
|
||||
max_rec = schema["properties"]["max_recursion_depth"]
|
||||
assert (
|
||||
max_rec.get("maximum") == 10 or max_rec.get("exclusiveMaximum", 999) <= 11
|
||||
)
|
||||
|
||||
def test_output_schema_has_no_duplicate_error_field(self):
|
||||
"""Output should inherit error from BlockSchemaOutput, not redefine it."""
|
||||
# The field should exist (inherited) but there should be no explicit
|
||||
# redefinition. We verify by checking the class __annotations__ directly.
|
||||
assert "error" not in AutoPilotBlock.Output.__annotations__
|
||||
@@ -1,9 +1,18 @@
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import anthropic
|
||||
import httpx
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
# TEST_CREDENTIALS_INPUT is a plain dict that satisfies AICredentials at runtime
|
||||
# but not at the type level. Cast once here to avoid per-test suppressors.
|
||||
_TEST_AI_CREDENTIALS = cast(llm.AICredentials, llm.TEST_CREDENTIALS_INPUT)
|
||||
|
||||
|
||||
class TestLLMStatsTracking:
|
||||
"""Test that LLM blocks correctly track token usage statistics."""
|
||||
@@ -13,18 +22,17 @@ class TestLLMStatsTracking:
|
||||
"""Test that llm_call returns proper token counts in LLMResponse."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
# Mock the OpenAI client
|
||||
# Mock the OpenAI Responses API response
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [
|
||||
MagicMock(message=MagicMock(content="Test response", tool_calls=None))
|
||||
]
|
||||
mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20)
|
||||
mock_response.output_text = "Test response"
|
||||
mock_response.output = []
|
||||
mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
|
||||
|
||||
# Test with mocked OpenAI response
|
||||
with patch("openai.AsyncOpenAI") as mock_openai:
|
||||
mock_client = AsyncMock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
mock_client.responses.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
response = await llm.llm_call(
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
@@ -271,30 +279,17 @@ class TestLLMStatsTracking:
|
||||
mock_response = MagicMock()
|
||||
# Return different responses for chunk summary vs final summary
|
||||
if call_count == 1:
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>',
|
||||
tool_calls=None,
|
||||
)
|
||||
)
|
||||
]
|
||||
mock_response.output_text = '<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>'
|
||||
else:
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>',
|
||||
tool_calls=None,
|
||||
)
|
||||
)
|
||||
]
|
||||
mock_response.usage = MagicMock(prompt_tokens=50, completion_tokens=30)
|
||||
mock_response.output_text = '<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>'
|
||||
mock_response.output = []
|
||||
mock_response.usage = MagicMock(input_tokens=50, output_tokens=30)
|
||||
return mock_response
|
||||
|
||||
with patch("openai.AsyncOpenAI") as mock_openai:
|
||||
mock_client = AsyncMock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_client.chat.completions.create = mock_create
|
||||
mock_client.responses.create = mock_create
|
||||
|
||||
# Test with very short text (should only need 1 chunk + 1 final summary)
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
@@ -669,3 +664,148 @@ class TestAITextSummarizerValidation:
|
||||
error_message = str(exc_info.value)
|
||||
assert "Expected a string summary" in error_message
|
||||
assert "received dict" in error_message
|
||||
|
||||
|
||||
def _make_anthropic_status_error(status_code: int) -> anthropic.APIStatusError:
|
||||
"""Create an anthropic.APIStatusError with the given status code."""
|
||||
request = httpx.Request("POST", "https://api.anthropic.com/v1/messages")
|
||||
response = httpx.Response(status_code, request=request)
|
||||
return anthropic.APIStatusError(
|
||||
f"Error code: {status_code}", response=response, body=None
|
||||
)
|
||||
|
||||
|
||||
def _make_openai_status_error(status_code: int) -> openai.APIStatusError:
|
||||
"""Create an openai.APIStatusError with the given status code."""
|
||||
response = httpx.Response(
|
||||
status_code, request=httpx.Request("POST", "https://api.openai.com/v1/chat")
|
||||
)
|
||||
return openai.APIStatusError(
|
||||
f"Error code: {status_code}", response=response, body=None
|
||||
)
|
||||
|
||||
|
||||
class TestUserErrorStatusCodeHandling:
|
||||
"""Test that user-caused LLM API errors (401/403/429) break the retry loop
|
||||
and are logged as warnings, while server errors (500) trigger retries."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("status_code", [401, 403, 429])
|
||||
async def test_anthropic_user_error_breaks_retry_loop(self, status_code: int):
|
||||
"""401/403/429 Anthropic errors should break immediately, not retry."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_anthropic_status_error(status_code)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count == 1
|
||||
), f"Expected exactly 1 call for status {status_code}, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("status_code", [401, 403, 429])
|
||||
async def test_openai_user_error_breaks_retry_loop(self, status_code: int):
|
||||
"""401/403/429 OpenAI errors should break immediately, not retry."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_openai_status_error(status_code)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count == 1
|
||||
), f"Expected exactly 1 call for status {status_code}, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_error_retries(self):
|
||||
"""500 errors should be retried (not break immediately)."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_anthropic_status_error(500)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count > 1
|
||||
), f"Expected multiple retry attempts for 500, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_error_logs_warning_not_exception(self):
|
||||
"""User-caused errors should log with logger.warning, not logger.exception."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
raise _make_anthropic_status_error(401)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(llm.logger, "warning") as mock_warning,
|
||||
patch.object(llm.logger, "exception") as mock_exception,
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
mock_warning.assert_called_once()
|
||||
mock_exception.assert_not_called()
|
||||
|
||||
@@ -57,7 +57,7 @@ async def execute_graph(
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data import graph
|
||||
|
||||
test_user = await create_test_user()
|
||||
@@ -66,7 +66,7 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
block_id=OrchestratorBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": creds,
|
||||
@@ -108,10 +108,10 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
async def test_orchestrator_function_signature(server: SpinTestServer):
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data import graph
|
||||
|
||||
test_user = await create_test_user()
|
||||
@@ -120,7 +120,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
block_id=OrchestratorBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": creds,
|
||||
@@ -169,7 +169,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
)
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
tool_functions = await SmartDecisionMakerBlock._create_tool_node_signatures(
|
||||
tool_functions = await OrchestratorBlock._create_tool_node_signatures(
|
||||
test_graph.nodes[0].id
|
||||
)
|
||||
assert tool_functions is not None, "Tool functions should not be None"
|
||||
@@ -198,12 +198,12 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_tracks_llm_stats():
|
||||
"""Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
|
||||
async def test_orchestrator_tracks_llm_stats():
|
||||
"""Test that OrchestratorBlock correctly tracks LLM usage stats."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock the llm.llm_call function to return controlled data
|
||||
mock_response = MagicMock()
|
||||
@@ -224,14 +224,14 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
|
||||
# Create test input
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Should I continue with this task?",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -274,12 +274,12 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_parameter_validation():
|
||||
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
|
||||
async def test_orchestrator_parameter_validation():
|
||||
"""Test that OrchestratorBlock correctly validates tool call parameters."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool functions with specific parameter schema
|
||||
mock_tool_functions = [
|
||||
@@ -327,13 +327,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_with_typo,
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -394,13 +394,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_missing_required,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -454,13 +454,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_valid,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -518,13 +518,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_all_params,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -562,12 +562,12 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_raw_response_conversion():
|
||||
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
|
||||
async def test_orchestrator_raw_response_conversion():
|
||||
"""Test that Orchestrator correctly handles different raw_response types with retry mechanism."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool functions
|
||||
mock_tool_functions = [
|
||||
@@ -637,7 +637,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
@@ -646,7 +646,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
# Second call returns successful response
|
||||
mock_llm_call.side_effect = [mock_response_retry, mock_response_success]
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -715,12 +715,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_ollama,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[], # No tools for this test
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -771,12 +771,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_dict,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -811,12 +811,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_mode():
|
||||
async def test_orchestrator_agent_mode():
|
||||
"""Test that agent mode executes tools directly and loops until finished."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool call that requires multiple iterations
|
||||
mock_tool_call_1 = MagicMock()
|
||||
@@ -893,7 +893,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
|
||||
), patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
), patch(
|
||||
"backend.executor.manager.async_update_node_execution_status",
|
||||
@@ -929,7 +929,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
}
|
||||
|
||||
# Test agent mode with max_iterations = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Complete this task using tools",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -969,12 +969,12 @@ async def test_smart_decision_maker_agent_mode():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_traditional_mode_default():
|
||||
async def test_orchestrator_traditional_mode_default():
|
||||
"""Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool call
|
||||
mock_tool_call = MagicMock()
|
||||
@@ -1018,7 +1018,7 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
):
|
||||
|
||||
# Test default behavior (traditional mode)
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -1060,12 +1060,12 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
|
||||
async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
"""Test that OrchestratorBlock uses customized_name from node metadata for tool names."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
@@ -1080,7 +1080,7 @@ async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
result = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1091,12 +1091,12 @@ async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
|
||||
async def test_orchestrator_falls_back_to_block_name():
|
||||
"""Test that OrchestratorBlock falls back to block.name when no customized_name."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
@@ -1111,7 +1111,7 @@ async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
result = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1122,11 +1122,11 @@ async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
|
||||
async def test_orchestrator_uses_customized_name_for_agents():
|
||||
"""Test that OrchestratorBlock uses customized_name from metadata for agent nodes."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
@@ -1152,10 +1152,10 @@ async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
result = await OrchestratorBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1166,11 +1166,11 @@ async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
async def test_orchestrator_agent_falls_back_to_graph_name():
|
||||
"""Test that agent node falls back to graph name when no customized_name."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
@@ -1196,10 +1196,10 @@ async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
result = await OrchestratorBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -3,12 +3,12 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
|
||||
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
"""Test Smart Decision Maker can handle dynamic dictionary fields (_#_) for any block"""
|
||||
async def test_orchestrator_handles_dynamic_dict_fields():
|
||||
"""Test Orchestrator can handle dynamic dictionary fields (_#_) for any block"""
|
||||
|
||||
# Create a mock node for CreateDictionaryBlock
|
||||
mock_node = Mock()
|
||||
@@ -23,24 +23,24 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
source_name="tools_^_create_dict_~_name",
|
||||
sink_name="values_#_name", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_age",
|
||||
sink_name="values_#_age", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_city",
|
||||
sink_name="values_#_city", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
signature = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, mock_links # type: ignore
|
||||
)
|
||||
|
||||
@@ -70,8 +70,8 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
"""Test Smart Decision Maker can handle dynamic list fields (_$_) for any block"""
|
||||
async def test_orchestrator_handles_dynamic_list_fields():
|
||||
"""Test Orchestrator can handle dynamic list fields (_$_) for any block"""
|
||||
|
||||
# Create a mock node for AddToListBlock
|
||||
mock_node = Mock()
|
||||
@@ -86,18 +86,18 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
source_name="tools_^_add_to_list_~_0",
|
||||
sink_name="entries_$_0", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_to_list_~_1",
|
||||
sink_name="entries_$_1", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
signature = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, mock_links # type: ignore
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
|
||||
"""Comprehensive tests for OrchestratorBlock dynamic field handling."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
import pytest
|
||||
|
||||
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.blocks.text import MatchTextPatternBlock
|
||||
from backend.data.dynamic_fields import get_dynamic_field_description
|
||||
|
||||
@@ -37,7 +37,7 @@ async def test_dynamic_field_description_generation():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_dict_fields():
|
||||
"""Test that function signatures are created correctly for dictionary dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for CreateDictionaryBlock
|
||||
mock_node = Mock()
|
||||
@@ -52,19 +52,19 @@ async def test_create_block_function_signature_with_dict_fields():
|
||||
source_name="tools_^_create_dict_~_values___name", # Sanitized source
|
||||
sink_name="values_#_name", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___age", # Sanitized source
|
||||
sink_name="values_#_age", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___email", # Sanitized source
|
||||
sink_name="values_#_email", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -100,7 +100,7 @@ async def test_create_block_function_signature_with_dict_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_list_fields():
|
||||
"""Test that function signatures are created correctly for list dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for AddToListBlock
|
||||
mock_node = Mock()
|
||||
@@ -115,19 +115,19 @@ async def test_create_block_function_signature_with_list_fields():
|
||||
source_name="tools_^_add_list_~_0",
|
||||
sink_name="entries_$_0", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_1",
|
||||
sink_name="entries_$_1", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_2",
|
||||
sink_name="entries_$_2", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -154,7 +154,7 @@ async def test_create_block_function_signature_with_list_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_object_fields():
|
||||
"""Test that function signatures are created correctly for object dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for MatchTextPatternBlock (simulating object fields)
|
||||
mock_node = Mock()
|
||||
@@ -169,13 +169,13 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
source_name="tools_^_extract_~_user_name",
|
||||
sink_name="data_@_user_name", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_extract_~_user_email",
|
||||
sink_name="data_@_user_email", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -197,11 +197,11 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_tool_node_signatures():
|
||||
"""Test that the mapping between sanitized and original field names is built correctly."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock the database client and connected nodes
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db:
|
||||
mock_client = AsyncMock()
|
||||
mock_db.return_value = mock_client
|
||||
@@ -281,7 +281,7 @@ async def test_create_tool_node_signatures():
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_yielding_with_dynamic_fields():
|
||||
"""Test that outputs are yielded correctly with dynamic field names mapped back."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# No more sanitized mapping needed since we removed sanitization
|
||||
|
||||
@@ -309,13 +309,13 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
# Mock the database manager to avoid HTTP calls during tool execution
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db_manager, patch.object(
|
||||
block, "_create_tool_node_signatures", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
@@ -420,7 +420,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_regular_and_dynamic_fields():
|
||||
"""Test handling of blocks with both regular and dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node
|
||||
mock_node = Mock()
|
||||
@@ -450,19 +450,19 @@ async def test_mixed_regular_and_dynamic_fields():
|
||||
source_name="tools_^_test_~_regular",
|
||||
sink_name="regular_field", # Regular field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key",
|
||||
sink_name="values_#_key1", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key2",
|
||||
sink_name="values_#_key2", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -488,7 +488,7 @@ async def test_mixed_regular_and_dynamic_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_errors_dont_pollute_conversation():
|
||||
"""Test that validation errors are only used during retries and don't pollute the conversation."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Track conversation history changes
|
||||
conversation_snapshots = []
|
||||
@@ -535,7 +535,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.side_effect = mock_llm_call
|
||||
|
||||
@@ -565,7 +565,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
|
||||
# Mock the database manager to avoid HTTP calls during tool execution
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db_manager:
|
||||
# Set up the mock database manager for agent mode
|
||||
mock_db_client = AsyncMock()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -290,7 +290,9 @@ class FillTextTemplateBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
formatter = text.TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", formatter.format_string(input_data.format, input_data.values)
|
||||
yield "output", await formatter.format_string(
|
||||
input_data.format, input_data.values
|
||||
)
|
||||
|
||||
|
||||
class CombineTextsBlock(Block):
|
||||
|
||||
@@ -115,10 +115,22 @@ class ChatConfig(BaseSettings):
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
use_openrouter: bool = Field(
|
||||
default=True,
|
||||
description="Enable routing API calls through the OpenRouter proxy. "
|
||||
"The actual decision also requires ``api_key`` and ``base_url`` — "
|
||||
"use the ``openrouter_active`` property for the final answer.",
|
||||
)
|
||||
use_claude_code_subscription: bool = Field(
|
||||
default=False,
|
||||
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
|
||||
)
|
||||
test_mode: bool = Field(
|
||||
default=False,
|
||||
description="Use dummy service instead of real LLM calls. "
|
||||
"Send __test_transient_error__, __test_fatal_error__, or "
|
||||
"__test_slow_response__ to trigger specific scenarios.",
|
||||
)
|
||||
|
||||
# E2B Sandbox Configuration
|
||||
use_e2b_sandbox: bool = Field(
|
||||
@@ -136,7 +148,7 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
|
||||
default=420, # 7 min safety net — allows headroom for compaction retries
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
@@ -146,6 +158,21 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
|
||||
)
|
||||
|
||||
@property
|
||||
def openrouter_active(self) -> bool:
|
||||
"""True when OpenRouter is enabled AND credentials are usable.
|
||||
|
||||
Single source of truth for "will the SDK route through OpenRouter?".
|
||||
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
|
||||
present — mirrors the fallback logic in ``_build_sdk_env``.
|
||||
"""
|
||||
if not self.use_openrouter:
|
||||
return False
|
||||
base = (self.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
return bool(self.api_key and base and base.startswith("http"))
|
||||
|
||||
@property
|
||||
def e2b_active(self) -> bool:
|
||||
"""True when E2B is enabled and the API key is present.
|
||||
@@ -168,15 +195,6 @@ class ChatConfig(BaseSettings):
|
||||
"""
|
||||
return self.e2b_api_key if self.e2b_active else None
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
def get_use_e2b_sandbox(cls, v):
|
||||
"""Get use_e2b_sandbox from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("e2b_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_e2b_api_key(cls, v):
|
||||
@@ -219,26 +237,6 @@ class ChatConfig(BaseSettings):
|
||||
v = OPENROUTER_BASE_URL
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||
# Check environment variable - default to True if not set
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("use_claude_code_subscription", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_code_subscription(cls, v):
|
||||
"""Get use_claude_code_subscription from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return False if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
@@ -248,6 +246,7 @@ class ChatConfig(BaseSettings):
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
env_prefix = "CHAT_"
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore" # Ignore extra environment variables
|
||||
|
||||
@@ -6,19 +6,70 @@ from .config import ChatConfig
|
||||
|
||||
# Env vars that the ChatConfig validators read — must be cleared so they don't
|
||||
# override the explicit constructor values we pass in each test.
|
||||
_E2B_ENV_VARS = (
|
||||
_ENV_VARS_TO_CLEAR = (
|
||||
"CHAT_USE_E2B_SANDBOX",
|
||||
"CHAT_E2B_API_KEY",
|
||||
"E2B_API_KEY",
|
||||
"CHAT_USE_OPENROUTER",
|
||||
"CHAT_API_KEY",
|
||||
"OPEN_ROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"CHAT_BASE_URL",
|
||||
"OPENROUTER_BASE_URL",
|
||||
"OPENAI_BASE_URL",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _E2B_ENV_VARS:
|
||||
def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _ENV_VARS_TO_CLEAR:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestOpenrouterActive:
|
||||
"""Tests for the openrouter_active property."""
|
||||
|
||||
def test_enabled_with_credentials_returns_true(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is True
|
||||
|
||||
def test_enabled_but_missing_api_key_returns_false(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key=None,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
def test_disabled_returns_false_despite_credentials(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
def test_strips_v1_suffix_and_still_valid(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is True
|
||||
|
||||
def test_invalid_base_url_returns_false(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="not-a-url",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
|
||||
class TestE2BActive:
|
||||
"""Tests for the e2b_active property — single source of truth for E2B usage."""
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
# The hex suffix makes accidental LLM generation of these strings virtually
|
||||
# impossible, avoiding false-positive marker detection in normal conversation.
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX = (
|
||||
"[__COPILOT_RETRYABLE_ERROR_a9c2__]" # ErrorCard + retry
|
||||
)
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Prefix for all synthetic IDs generated by CoPilot block execution.
|
||||
@@ -35,3 +38,24 @@ def parse_node_id_from_exec_id(node_exec_id: str) -> str:
|
||||
Format: "{node_id}:{random_hex}" → returns "{node_id}".
|
||||
"""
|
||||
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transient Anthropic API error detection
|
||||
# ---------------------------------------------------------------------------
|
||||
# Patterns in error text that indicate a transient Anthropic API error
|
||||
# (ECONNRESET / dropped TCP connection) which is retryable.
|
||||
_TRANSIENT_ERROR_PATTERNS = (
|
||||
"socket connection was closed unexpectedly",
|
||||
"ECONNRESET",
|
||||
"connection was forcibly closed",
|
||||
"network socket disconnected",
|
||||
)
|
||||
|
||||
FRIENDLY_TRANSIENT_MSG = "Anthropic connection interrupted — please retry"
|
||||
|
||||
|
||||
def is_transient_api_error(error_text: str) -> bool:
|
||||
"""Return True if *error_text* matches a known transient Anthropic API error."""
|
||||
lower = error_text.lower()
|
||||
return any(pat.lower() in lower for pat in _TRANSIENT_ERROR_PATTERNS)
|
||||
|
||||
@@ -17,8 +17,20 @@ from backend.util.workspace import WorkspaceManager
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
# Allowed base directory for the Read tool.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
|
||||
# Allowed base directory for the Read tool. Public so service.py can use it
|
||||
# for sweep operations without depending on a private implementation detail.
|
||||
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
|
||||
# _projects_base() function.
|
||||
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))
|
||||
|
||||
# Compiled UUID pattern for validating conversation directory names.
|
||||
# Kept as a module-level constant so the security-relevant pattern is easy
|
||||
# to audit in one place and avoids recompilation on every call.
|
||||
_UUID_RE = re.compile(r"^[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}$", re.IGNORECASE)
|
||||
|
||||
# Encoded project-directory name for the current session (e.g.
|
||||
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
|
||||
@@ -34,17 +46,33 @@ _current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
)
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
# Current execution's capability filter. None means "no restrictions".
|
||||
# Set by set_execution_context(); read by run_block and service.py.
|
||||
_current_permissions: "ContextVar[CopilotPermissions | None]" = ContextVar(
|
||||
"_current_permissions", default=None
|
||||
)
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does."""
|
||||
|
||||
def encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The Claude CLI encodes the absolute cwd as a directory name by replacing
|
||||
every non-alphanumeric character with ``-``. For example
|
||||
``/tmp/copilot-abc`` becomes ``-tmp-copilot-abc``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
# Keep the private alias for internal callers (backwards compat).
|
||||
_encode_cwd_for_cli = encode_cwd_for_cli
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> None:
|
||||
"""Set per-turn context variables used by file-resolution tool handlers."""
|
||||
_current_user_id.set(user_id)
|
||||
@@ -52,6 +80,7 @@ def set_execution_context(
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
_current_permissions.set(permissions)
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
@@ -59,6 +88,11 @@ def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
return _current_user_id.get(), _current_session.get()
|
||||
|
||||
|
||||
def get_current_permissions() -> "CopilotPermissions | None":
|
||||
"""Return the capability filter for the current execution, or None if unrestricted."""
|
||||
return _current_permissions.get()
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current session, or None if not active."""
|
||||
return _current_sandbox.get()
|
||||
@@ -70,17 +104,32 @@ def get_sdk_cwd() -> str:
|
||||
|
||||
|
||||
E2B_WORKDIR = "/home/user"
|
||||
E2B_ALLOWED_DIRS: tuple[str, ...] = (E2B_WORKDIR, "/tmp")
|
||||
E2B_ALLOWED_DIRS_STR: str = " or ".join(E2B_ALLOWED_DIRS)
|
||||
|
||||
|
||||
def is_within_allowed_dirs(path: str) -> bool:
|
||||
"""Return True if *path* is within one of the allowed sandbox directories."""
|
||||
for allowed in E2B_ALLOWED_DIRS:
|
||||
if path == allowed or path.startswith(allowed + "/"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def resolve_sandbox_path(path: str) -> str:
|
||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||
"""Normalise *path* to an absolute sandbox path under an allowed directory.
|
||||
|
||||
Allowed directories: ``/home/user`` and ``/tmp``.
|
||||
Relative paths are resolved against ``/home/user``.
|
||||
|
||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||
"""
|
||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||
normalized = os.path.normpath(candidate)
|
||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||
if not is_within_allowed_dirs(normalized):
|
||||
raise ValueError(
|
||||
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(path)}"
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
@@ -100,7 +149,9 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
|
||||
The SDK nests tool-results under a conversation UUID directory;
|
||||
the UUID segment is validated with ``_UUID_RE``.
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
@@ -119,10 +170,22 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
if resolved == tool_results_dir or resolved.startswith(
|
||||
tool_results_dir + os.sep
|
||||
):
|
||||
return True
|
||||
project_dir = os.path.realpath(os.path.join(SDK_PROJECTS_DIR, encoded))
|
||||
# Defence-in-depth: ensure project_dir didn't escape the base.
|
||||
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
|
||||
return False
|
||||
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
|
||||
# The SDK always creates a conversation UUID directory between
|
||||
# the project dir and tool-results/.
|
||||
if resolved.startswith(project_dir + os.sep):
|
||||
relative = resolved[len(project_dir) + 1 :]
|
||||
parts = relative.split(os.sep)
|
||||
# Require exactly: [<uuid>, "tool-results", <file>, ...]
|
||||
if (
|
||||
len(parts) >= 3
|
||||
and _UUID_RE.match(parts[0])
|
||||
and parts[1] == "tool-results"
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -9,8 +9,9 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import (
|
||||
_SDK_PROJECTS_DIR,
|
||||
SDK_PROJECTS_DIR,
|
||||
_current_project_dir,
|
||||
get_current_permissions,
|
||||
get_current_sandbox,
|
||||
get_execution_context,
|
||||
get_sdk_cwd,
|
||||
@@ -18,6 +19,7 @@ from backend.copilot.context import (
|
||||
resolve_sandbox_path,
|
||||
set_execution_context,
|
||||
)
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
|
||||
def _make_session() -> MagicMock:
|
||||
@@ -61,6 +63,19 @@ def test_get_current_sandbox_returns_set_value():
|
||||
assert get_current_sandbox() is mock_sandbox
|
||||
|
||||
|
||||
def test_set_and_get_current_permissions():
|
||||
"""set_execution_context stores permissions; get_current_permissions returns it."""
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
set_execution_context("u1", _make_session(), permissions=perms)
|
||||
assert get_current_permissions() is perms
|
||||
|
||||
|
||||
def test_get_current_permissions_defaults_to_none():
|
||||
"""get_current_permissions returns None when no permissions have been set."""
|
||||
set_execution_context("u1", _make_session())
|
||||
assert get_current_permissions() is None
|
||||
|
||||
|
||||
def test_get_sdk_cwd_empty_when_not_set():
|
||||
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
|
||||
set_execution_context("u1", _make_session(), sdk_cwd=None)
|
||||
@@ -104,11 +119,13 @@ def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
|
||||
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_dir():
|
||||
"""Files under the tool-results directory for the current project are allowed."""
|
||||
def test_is_allowed_local_path_tool_results_with_uuid():
|
||||
"""Files under <encoded-cwd>/<uuid>/tool-results/ are allowed."""
|
||||
encoded = "test-encoded-dir"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
path = os.path.join(tool_results_dir, "output.txt")
|
||||
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results", "output.txt"
|
||||
)
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -117,10 +134,22 @@ def test_is_allowed_local_path_tool_results_dir():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
|
||||
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
path = os.path.join(SDK_PROJECTS_DIR, encoded, "tool-results", "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
"""A path adjacent to tool-results/ but not inside it is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
sibling_path = os.path.join(SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -129,6 +158,21 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
|
||||
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
uuid_str = "12345678-1234-5678-9abc-def012345678"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, uuid_str, "not-tool-results", "output.txt"
|
||||
)
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -154,10 +198,32 @@ def test_resolve_sandbox_path_normalizes_dots():
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_escape_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_absolute_outside_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
with pytest.raises(ValueError):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_tmp_allowed():
|
||||
assert resolve_sandbox_path("/tmp/data.txt") == "/tmp/data.txt"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_tmp_nested():
|
||||
assert resolve_sandbox_path("/tmp/a/b/c.txt") == "/tmp/a/b/c.txt"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_tmp_itself():
|
||||
assert resolve_sandbox_path("/tmp") == "/tmp"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_tmp_escape_raises():
|
||||
with pytest.raises(ValueError):
|
||||
resolve_sandbox_path("/tmp/../etc/passwd")
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_tmp_prefix_collision_raises():
|
||||
with pytest.raises(ValueError):
|
||||
resolve_sandbox_path("/tmp_evil/malicious.txt")
|
||||
|
||||
@@ -14,14 +14,16 @@ import time
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.response_model import StreamError
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.process import set_service_name
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
||||
|
||||
@@ -152,8 +154,6 @@ class CoPilotProcessor:
|
||||
worker's event loop, ensuring ``aiohttp.ClientSession.close()``
|
||||
runs on the same loop that created the session.
|
||||
"""
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
coro = shutdown_workspace_storage()
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
|
||||
@@ -246,48 +246,58 @@ class CoPilotProcessor:
|
||||
# Choose service based on LaunchDarkly flag.
|
||||
# Claude Code subscription forces SDK mode (CLI subprocess auth).
|
||||
config = ChatConfig()
|
||||
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
if config.test_mode:
|
||||
stream_fn = stream_chat_completion_dummy
|
||||
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
|
||||
else:
|
||||
use_sdk = (
|
||||
config.use_claude_code_subscription
|
||||
or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
async for chunk in stream_fn(
|
||||
# stream_and_publish wraps the raw stream with registry
|
||||
# publishing (shared with collect_copilot_response).
|
||||
raw_stream = stream_fn(
|
||||
session_id=entry.session_id,
|
||||
message=entry.message if entry.message else None,
|
||||
is_user_message=entry.is_user_message,
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
)
|
||||
async for chunk in stream_registry.stream_and_publish(
|
||||
session_id=entry.session_id,
|
||||
turn_id=entry.turn_id,
|
||||
stream=raw_stream,
|
||||
):
|
||||
if cancel.is_set():
|
||||
log.info("Cancel requested, breaking stream")
|
||||
break
|
||||
|
||||
# Capture StreamError so mark_session_completed receives
|
||||
# the error message (stream_and_publish yields but does
|
||||
# not publish StreamError — that's done by mark_session_completed).
|
||||
if isinstance(chunk, StreamError):
|
||||
error_msg = chunk.errorText
|
||||
break
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= refresh_interval:
|
||||
cluster_lock.refresh()
|
||||
last_refresh = current_time
|
||||
|
||||
# Skip StreamFinish — mark_session_completed publishes it.
|
||||
if isinstance(chunk, StreamFinish):
|
||||
continue
|
||||
|
||||
try:
|
||||
await stream_registry.publish_chunk(entry.turn_id, chunk)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error publishing chunk {type(chunk).__name__}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Stream loop completed
|
||||
if cancel.is_set():
|
||||
log.info("Stream cancelled by user")
|
||||
|
||||
173
autogpt_platform/backend/backend/copilot/integration_creds.py
Normal file
173
autogpt_platform/backend/backend/copilot/integration_creds.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Integration credential lookup with per-process TTL cache.
|
||||
|
||||
Provides token retrieval for connected integrations so that copilot tools
|
||||
(e.g. bash_exec) can inject auth tokens into the execution environment without
|
||||
hitting the database on every command.
|
||||
|
||||
Cache semantics (handled automatically by TTLCache):
|
||||
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
|
||||
for users who have credentials and are running many bash commands.
|
||||
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
|
||||
on every E2B command for users who haven't connected an account yet, while
|
||||
still picking up a newly-connected account within one minute.
|
||||
|
||||
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
|
||||
least-recently-used entry when the limit is reached.
|
||||
|
||||
Multi-worker note: both caches are in-process only. Each worker/replica
|
||||
maintains its own independent cache, so a credential fetch may be duplicated
|
||||
across processes. This is acceptable for the current goal (reduce DB hits per
|
||||
session per-process), but if cache efficiency across replicas becomes important
|
||||
a shared cache (e.g. Redis) should be used instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from backend.copilot.providers import SUPPORTED_PROVIDERS
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
register_creds_changed_hook,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Derived from the single SUPPORTED_PROVIDERS registry for backward compat.
|
||||
PROVIDER_ENV_VARS: dict[str, list[str]] = {
|
||||
slug: entry["env_vars"] for slug, entry in SUPPORTED_PROVIDERS.items()
|
||||
}
|
||||
|
||||
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
|
||||
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
|
||||
_CACHE_MAX_SIZE = 10_000
|
||||
|
||||
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
|
||||
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
|
||||
# because all callers (get_provider_token, invalidate_user_provider_cache) run
|
||||
# exclusively on the asyncio event loop. There are no await points between a
|
||||
# cache read and its corresponding write within any function, so no concurrent
|
||||
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
|
||||
# this path, a threading.RLock should be wrapped around these caches.
|
||||
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
|
||||
)
|
||||
# Separate cache for "no credentials" results with a shorter TTL.
|
||||
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
# entries. This avoids a lazy import inside creds_manager and eliminates the
|
||||
# circular-import risk.
|
||||
try:
|
||||
register_creds_changed_hook(invalidate_user_provider_cache)
|
||||
except RuntimeError:
|
||||
# Hook already registered (e.g. module re-import in tests).
|
||||
pass
|
||||
|
||||
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
|
||||
# on every cache-miss call to get_provider_token().
|
||||
_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
"""Return the user's access token for *provider*, or ``None`` if not connected.
|
||||
|
||||
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
|
||||
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
|
||||
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
|
||||
command for users who haven't connected yet, while still picking up a
|
||||
newly-connected account within one minute.
|
||||
"""
|
||||
cache_key = (user_id, provider)
|
||||
|
||||
if cache_key in _null_cache:
|
||||
return None
|
||||
if cached := _token_cache.get(cache_key):
|
||||
return cached
|
||||
|
||||
manager = _manager
|
||||
try:
|
||||
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch %s credentials for user %s",
|
||||
provider,
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
|
||||
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
|
||||
# full git access, while a public-data-only token lacks push/pull permission.
|
||||
# lock=False — background injection; not worth a distributed lock acquisition.
|
||||
oauth2_creds = sorted(
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
fresh = await manager.refresh_if_needed(
|
||||
user_id, cast(OAuth2Credentials, creds), lock=False
|
||||
)
|
||||
token = fresh.access_token.get_secret_value()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh %s OAuth token for user %s; "
|
||||
"discarding stale token to force re-auth",
|
||||
provider,
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
# Do NOT fall back to the stale token — it is likely expired
|
||||
# or revoked. Returning None forces the caller to re-auth,
|
||||
# preventing the LLM from receiving a non-functional token.
|
||||
continue
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# Pass 2: fall back to API key (no expiry, no refresh needed).
|
||||
for creds in creds_list:
|
||||
if creds.type == "api_key":
|
||||
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
"""Return env vars for all providers the user has connected.
|
||||
|
||||
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
|
||||
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
|
||||
Only providers with a stored credential contribute entries.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
for provider, var_names in PROVIDER_ENV_VARS.items():
|
||||
token = await get_provider_token(user_id, provider)
|
||||
if token:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env
|
||||
@@ -0,0 +1,195 @@
|
||||
"""Tests for integration_creds — TTL cache and token lookup paths."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.copilot.integration_creds import (
|
||||
_NULL_CACHE_TTL,
|
||||
_TOKEN_CACHE_TTL,
|
||||
PROVIDER_ENV_VARS,
|
||||
_null_cache,
|
||||
_token_cache,
|
||||
get_integration_env_vars,
|
||||
get_provider_token,
|
||||
invalidate_user_provider_cache,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
|
||||
_USER = "user-integration-creds-test"
|
||||
_PROVIDER = "github"
|
||||
|
||||
|
||||
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="creds-api-key",
|
||||
provider=_PROVIDER,
|
||||
api_key=SecretStr(key),
|
||||
title="Test API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id="creds-oauth2",
|
||||
provider=_PROVIDER,
|
||||
title="Test OAuth",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr("test-refresh"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_caches():
|
||||
"""Ensure clean caches before and after every test."""
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
yield
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
|
||||
|
||||
class TestInvalidateUserProviderCache:
|
||||
def test_removes_token_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_token_cache[key] = "tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _token_cache
|
||||
|
||||
def test_removes_null_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_null_cache[key] = True
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _null_cache
|
||||
|
||||
def test_noop_when_key_not_cached(self):
|
||||
# Should not raise even when there is no cache entry.
|
||||
invalidate_user_provider_cache("no-such-user", _PROVIDER)
|
||||
|
||||
def test_only_removes_targeted_key(self):
|
||||
other_key = ("other-user", _PROVIDER)
|
||||
_token_cache[other_key] = "other-tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert other_key in _token_cache
|
||||
|
||||
|
||||
class TestGetProviderToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_cached_token_without_db_hit(self):
|
||||
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "cached-tok"
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_none_for_null_cached_provider(self):
|
||||
_null_cache[(_USER, _PROVIDER)] = True
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_api_key_creds_returned_and_cached(self):
|
||||
api_creds = _make_api_key_creds("my-api-key")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "my-api-key"
|
||||
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_preferred_over_api_key(self):
|
||||
oauth_creds = _make_oauth2_creds("oauth-tok")
|
||||
api_creds = _make_api_key_creds("api-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[api_creds, oauth_creds]
|
||||
)
|
||||
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_returns_none(self):
|
||||
"""On refresh failure, return None instead of caching a stale token."""
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
# Stale tokens must NOT be returned — forces re-auth.
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
assert _null_cache.get((_USER, _PROVIDER)) is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_db_exception_returns_none_without_caching(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
side_effect=RuntimeError("db down")
|
||||
)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
# DB errors are not cached — next call will retry
|
||||
assert (_USER, _PROVIDER) not in _token_cache
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
|
||||
"""Verify the TTL constants are set correctly for each cache."""
|
||||
assert _null_cache.ttl == _NULL_CACHE_TTL
|
||||
assert _token_cache.ttl == _TOKEN_CACHE_TTL
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
_token_cache[(_USER, "github")] = "gh-tok"
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
for var in PROVIDER_ENV_VARS["github"]:
|
||||
assert result[var] == "gh-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_dict_when_no_credentials(self):
|
||||
_null_cache[(_USER, "github")] = True
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
assert result == {}
|
||||
430
autogpt_platform/backend/backend/copilot/permissions.py
Normal file
430
autogpt_platform/backend/backend/copilot/permissions.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""Copilot execution permissions — tool and block allow/deny filtering.
|
||||
|
||||
:class:`CopilotPermissions` is the single model used everywhere:
|
||||
|
||||
- ``AutoPilotBlock`` reads four block-input fields and builds one instance.
|
||||
- ``stream_chat_completion_sdk`` applies it when constructing
|
||||
``ClaudeAgentOptions.allowed_tools`` / ``disallowed_tools``.
|
||||
- ``run_block`` reads it from the contextvar to gate block execution.
|
||||
- Recursive (sub-agent) invocations merge parent and child so children
|
||||
can only be *more* restrictive, never more permissive.
|
||||
|
||||
Tool names
|
||||
----------
|
||||
Users specify the **short name** as it appears in ``TOOL_REGISTRY`` (e.g.
|
||||
``run_block``, ``web_fetch``) or as an SDK built-in (e.g. ``Read``,
|
||||
``Task``, ``WebSearch``). Internally these are mapped to the full SDK
|
||||
format (``mcp__copilot__run_block``, ``Read``, …) by
|
||||
:func:`apply_tool_permissions`.
|
||||
|
||||
Block identifiers
|
||||
-----------------
|
||||
Each entry in ``blocks`` may be one of:
|
||||
|
||||
- A **full UUID** (``c069dc6b-c3ed-4c12-b6e5-d47361e64ce6``)
|
||||
- A **partial UUID** — the first 8-character hex segment (``c069dc6b``)
|
||||
- A **block name** (case-insensitive, e.g. ``"HTTP Request"``)
|
||||
|
||||
:func:`validate_block_identifiers` resolves all entries against the live
|
||||
block registry and returns any that could not be matched.
|
||||
|
||||
Semantics
|
||||
---------
|
||||
``tools_exclude=True`` (default) — ``tools`` is a **blacklist**; listed
|
||||
tools are denied and everything else is allowed. An empty list means
|
||||
"allow all" (no filtering).
|
||||
|
||||
``tools_exclude=False`` — ``tools`` is a **whitelist**; only listed tools
|
||||
are allowed.
|
||||
|
||||
``blocks_exclude`` follows the same pattern for ``blocks``.
|
||||
|
||||
Recursion inheritance
|
||||
---------------------
|
||||
:meth:`CopilotPermissions.merged_with_parent` produces a new instance that
|
||||
is at most as permissive as the parent:
|
||||
|
||||
- Tools: effective-allowed sets are intersected then stored as a whitelist.
|
||||
- Blocks: the parent is stored in ``_parent`` and consulted during every
|
||||
:meth:`is_block_allowed` call so both constraints must pass.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Literal, get_args
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants — single source of truth for all accepted tool names
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Literal type combining all valid tool names — used by AutoPilotBlock.Input
|
||||
# so the frontend renders a multi-select dropdown.
|
||||
# This is the SINGLE SOURCE OF TRUTH. All other name sets are derived from it.
|
||||
ToolName = Literal[
|
||||
# Platform tools (must match keys in TOOL_REGISTRY)
|
||||
"add_understanding",
|
||||
"bash_exec",
|
||||
"browser_act",
|
||||
"browser_navigate",
|
||||
"browser_screenshot",
|
||||
"connect_integration",
|
||||
"continue_run_block",
|
||||
"create_agent",
|
||||
"create_feature_request",
|
||||
"create_folder",
|
||||
"customize_agent",
|
||||
"delete_folder",
|
||||
"delete_workspace_file",
|
||||
"edit_agent",
|
||||
"find_agent",
|
||||
"find_block",
|
||||
"find_library_agent",
|
||||
"fix_agent_graph",
|
||||
"get_agent_building_guide",
|
||||
"get_doc_page",
|
||||
"get_mcp_guide",
|
||||
"list_folders",
|
||||
"list_workspace_files",
|
||||
"move_agents_to_folder",
|
||||
"move_folder",
|
||||
"read_workspace_file",
|
||||
"run_agent",
|
||||
"run_block",
|
||||
"run_mcp_tool",
|
||||
"search_docs",
|
||||
"search_feature_requests",
|
||||
"update_folder",
|
||||
"validate_agent_graph",
|
||||
"view_agent_output",
|
||||
"web_fetch",
|
||||
"write_workspace_file",
|
||||
# SDK built-ins
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Read",
|
||||
"Task",
|
||||
"TodoWrite",
|
||||
"WebSearch",
|
||||
"Write",
|
||||
]
|
||||
|
||||
# Frozen set of all valid tool names — derived from the Literal.
|
||||
ALL_TOOL_NAMES: frozenset[str] = frozenset(get_args(ToolName))
|
||||
|
||||
# SDK built-in tool names — uppercase-initial names are SDK built-ins.
|
||||
SDK_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
|
||||
n for n in ALL_TOOL_NAMES if n[0].isupper()
|
||||
)
|
||||
|
||||
# Platform tool names — everything that isn't an SDK built-in.
|
||||
PLATFORM_TOOL_NAMES: frozenset[str] = ALL_TOOL_NAMES - SDK_BUILTIN_TOOL_NAMES
|
||||
|
||||
# Compiled regex patterns for block identifier classification.
|
||||
_FULL_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_PARTIAL_UUID_RE = re.compile(r"^[0-9a-f]{8}$", re.IGNORECASE)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper — block identifier matching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _block_matches(identifier: str, block_id: str, block_name: str) -> bool:
|
||||
"""Return True if *identifier* resolves to the given block.
|
||||
|
||||
Resolution order:
|
||||
1. Full UUID — exact case-insensitive match against *block_id*.
|
||||
2. Partial UUID (8 hex chars, first segment) — prefix match.
|
||||
3. Name — case-insensitive equality against *block_name*.
|
||||
"""
|
||||
ident = identifier.strip()
|
||||
if _FULL_UUID_RE.match(ident):
|
||||
return ident.lower() == block_id.lower()
|
||||
if _PARTIAL_UUID_RE.match(ident):
|
||||
return block_id.lower().startswith(ident.lower())
|
||||
return ident.lower() == block_name.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CopilotPermissions(BaseModel):
|
||||
"""Capability filter for a single copilot execution.
|
||||
|
||||
Attributes:
|
||||
tools: Tool names to filter (short names, e.g. ``run_block``).
|
||||
tools_exclude: When True (default) ``tools`` is a blacklist;
|
||||
when False it is a whitelist. Ignored when *tools* is empty.
|
||||
blocks: Block identifiers (name, full UUID, or 8-char partial UUID).
|
||||
blocks_exclude: Same semantics as *tools_exclude* but for blocks.
|
||||
"""
|
||||
|
||||
tools: list[str] = []
|
||||
tools_exclude: bool = True
|
||||
blocks: list[str] = []
|
||||
blocks_exclude: bool = True
|
||||
|
||||
# Private: parent permissions for recursion inheritance.
|
||||
# Set only by merged_with_parent(); never exposed in block input schema.
|
||||
_parent: CopilotPermissions | None = PrivateAttr(default=None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def effective_allowed_tools(self, all_tools: frozenset[str]) -> frozenset[str]:
|
||||
"""Compute the set of short tool names that are permitted.
|
||||
|
||||
Args:
|
||||
all_tools: Universe of valid short tool names.
|
||||
|
||||
Returns:
|
||||
Subset of *all_tools* that pass the filter.
|
||||
"""
|
||||
if not self.tools:
|
||||
return frozenset(all_tools)
|
||||
tool_set = frozenset(self.tools)
|
||||
if self.tools_exclude:
|
||||
return all_tools - tool_set
|
||||
return all_tools & tool_set
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Block helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_block_allowed(self, block_id: str, block_name: str) -> bool:
|
||||
"""Return True if the block may be executed under these permissions.
|
||||
|
||||
Checks this instance first, then consults the parent (if any) so
|
||||
the entire inheritance chain is respected.
|
||||
"""
|
||||
if not self._check_block_locally(block_id, block_name):
|
||||
return False
|
||||
if self._parent is not None:
|
||||
return self._parent.is_block_allowed(block_id, block_name)
|
||||
return True
|
||||
|
||||
def _check_block_locally(self, block_id: str, block_name: str) -> bool:
|
||||
"""Check *only* this instance's block filter (ignores parent)."""
|
||||
if not self.blocks:
|
||||
return True # No filter → allow all
|
||||
matched = any(
|
||||
_block_matches(identifier, block_id, block_name)
|
||||
for identifier in self.blocks
|
||||
)
|
||||
return not matched if self.blocks_exclude else matched
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Recursion / merging
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def merged_with_parent(
|
||||
self,
|
||||
parent: CopilotPermissions,
|
||||
all_tools: frozenset[str],
|
||||
) -> CopilotPermissions:
|
||||
"""Return a new instance that is at most as permissive as *parent*.
|
||||
|
||||
- Tools: intersection of effective-allowed sets, stored as a whitelist.
|
||||
- Blocks: parent is stored internally; both constraints are applied
|
||||
during :meth:`is_block_allowed`.
|
||||
"""
|
||||
merged_tools = self.effective_allowed_tools(
|
||||
all_tools
|
||||
) & parent.effective_allowed_tools(all_tools)
|
||||
result = CopilotPermissions(
|
||||
tools=sorted(merged_tools),
|
||||
tools_exclude=False,
|
||||
blocks=self.blocks,
|
||||
blocks_exclude=self.blocks_exclude,
|
||||
)
|
||||
result._parent = parent
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Convenience
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Return True when no filtering is configured (allow-all passthrough)."""
|
||||
return not self.tools and not self.blocks and self._parent is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def all_known_tool_names() -> frozenset[str]:
|
||||
"""Return all short tool names accepted in *tools*.
|
||||
|
||||
Returns the pre-computed ``ALL_TOOL_NAMES`` set (derived from the
|
||||
``ToolName`` Literal). On first call, also verifies consistency with
|
||||
the live ``TOOL_REGISTRY``.
|
||||
"""
|
||||
_assert_tool_names_consistent()
|
||||
return ALL_TOOL_NAMES
|
||||
|
||||
|
||||
def validate_tool_names(tools: list[str]) -> list[str]:
|
||||
"""Return entries in *tools* that are not valid tool names.
|
||||
|
||||
Args:
|
||||
tools: List of short tool name strings to validate.
|
||||
|
||||
Returns:
|
||||
List of invalid names (empty if all are valid).
|
||||
"""
|
||||
return [t for t in tools if t not in ALL_TOOL_NAMES]
|
||||
|
||||
|
||||
_tool_names_checked = False
|
||||
|
||||
|
||||
def _assert_tool_names_consistent() -> None:
|
||||
"""Verify that ``PLATFORM_TOOL_NAMES`` matches ``TOOL_REGISTRY`` keys.
|
||||
|
||||
Called once lazily (TOOL_REGISTRY has heavy imports). Raises
|
||||
``AssertionError`` with a helpful diff if they diverge.
|
||||
"""
|
||||
global _tool_names_checked
|
||||
if _tool_names_checked:
|
||||
return
|
||||
_tool_names_checked = True
|
||||
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
registry_keys: frozenset[str] = frozenset(TOOL_REGISTRY.keys())
|
||||
declared: frozenset[str] = PLATFORM_TOOL_NAMES
|
||||
if registry_keys != declared:
|
||||
missing = registry_keys - declared
|
||||
extra = declared - registry_keys
|
||||
parts: list[str] = [
|
||||
"PLATFORM_TOOL_NAMES in permissions.py is out of sync with TOOL_REGISTRY."
|
||||
]
|
||||
if missing:
|
||||
parts.append(f" Missing from PLATFORM_TOOL_NAMES: {sorted(missing)}")
|
||||
if extra:
|
||||
parts.append(f" Extra in PLATFORM_TOOL_NAMES: {sorted(extra)}")
|
||||
parts.append(" Update the ToolName Literal to match.")
|
||||
raise AssertionError("\n".join(parts))
|
||||
|
||||
|
||||
async def validate_block_identifiers(
|
||||
identifiers: list[str],
|
||||
) -> list[str]:
|
||||
"""Resolve each block identifier and return those that could not be matched.
|
||||
|
||||
Args:
|
||||
identifiers: List of block identifiers (name, full UUID, or partial UUID).
|
||||
|
||||
Returns:
|
||||
List of identifiers that matched no known block.
|
||||
"""
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
# get_blocks() returns dict[block_id_str, BlockClass]; instantiate once to get names.
|
||||
block_registry = get_blocks()
|
||||
block_info = {bid: cls().name for bid, cls in block_registry.items()}
|
||||
invalid: list[str] = []
|
||||
for ident in identifiers:
|
||||
matched = any(
|
||||
_block_matches(ident, bid, bname) for bid, bname in block_info.items()
|
||||
)
|
||||
if not matched:
|
||||
invalid.append(ident)
|
||||
return invalid
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDK tool-list application
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_tool_permissions(
|
||||
permissions: CopilotPermissions,
|
||||
*,
|
||||
use_e2b: bool = False,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Compute (allowed_tools, extra_disallowed) for :class:`ClaudeAgentOptions`.
|
||||
|
||||
Takes the base allowed/disallowed lists from
|
||||
:func:`~backend.copilot.sdk.tool_adapter.get_copilot_tool_names` /
|
||||
:func:`~backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools` and
|
||||
applies *permissions* on top.
|
||||
|
||||
Returns:
|
||||
``(allowed_tools, extra_disallowed)`` where *allowed_tools* is the
|
||||
possibly-narrowed list to pass to ``ClaudeAgentOptions.allowed_tools``
|
||||
and *extra_disallowed* is the list to pass to
|
||||
``ClaudeAgentOptions.disallowed_tools``.
|
||||
"""
|
||||
from backend.copilot.sdk.tool_adapter import (
|
||||
_READ_TOOL_NAME,
|
||||
MCP_TOOL_PREFIX,
|
||||
get_copilot_tool_names,
|
||||
get_sdk_disallowed_tools,
|
||||
)
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
base_allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
base_disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
|
||||
if permissions.is_empty():
|
||||
return base_allowed, base_disallowed
|
||||
|
||||
all_tools = all_known_tool_names()
|
||||
effective = permissions.effective_allowed_tools(all_tools)
|
||||
|
||||
# In E2B mode, SDK built-in file tools (Read, Write, Edit, Glob, Grep)
|
||||
# are replaced by MCP equivalents (read_file, write_file, ...).
|
||||
# Map each SDK built-in name to its E2B MCP name so users can use the
|
||||
# familiar names in their permissions and the E2B tools are included.
|
||||
_SDK_TO_E2B: dict[str, str] = {}
|
||||
if use_e2b:
|
||||
from backend.copilot.sdk.e2b_file_tools import E2B_FILE_TOOL_NAMES
|
||||
|
||||
_SDK_TO_E2B = dict(
|
||||
zip(
|
||||
["Read", "Write", "Edit", "Glob", "Grep"],
|
||||
E2B_FILE_TOOL_NAMES,
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
|
||||
# Build an updated allowed list by mapping short names → SDK names and
|
||||
# keeping only those present in the original base_allowed list.
|
||||
def to_sdk_names(short: str) -> list[str]:
|
||||
names: list[str] = []
|
||||
if short in TOOL_REGISTRY:
|
||||
names.append(f"{MCP_TOOL_PREFIX}{short}")
|
||||
elif short in _SDK_TO_E2B:
|
||||
# E2B mode: map SDK built-in file tool to its MCP equivalent.
|
||||
names.append(f"{MCP_TOOL_PREFIX}{_SDK_TO_E2B[short]}")
|
||||
else:
|
||||
names.append(short) # SDK built-in — used as-is
|
||||
return names
|
||||
|
||||
# short names permitted by permissions
|
||||
permitted_sdk: set[str] = set()
|
||||
for s in effective:
|
||||
permitted_sdk.update(to_sdk_names(s))
|
||||
# Always include the internal Read tool (used by SDK for large/truncated outputs)
|
||||
permitted_sdk.add(f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}")
|
||||
|
||||
filtered_allowed = [t for t in base_allowed if t in permitted_sdk]
|
||||
|
||||
# Extra disallowed = tools that were in base_allowed but are now removed
|
||||
removed = set(base_allowed) - set(filtered_allowed)
|
||||
extra_disallowed = list(set(base_disallowed) | removed)
|
||||
|
||||
return filtered_allowed, extra_disallowed
|
||||
579
autogpt_platform/backend/backend/copilot/permissions_test.py
Normal file
579
autogpt_platform/backend/backend/copilot/permissions_test.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""Tests for CopilotPermissions — tool/block capability filtering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.permissions import (
|
||||
ALL_TOOL_NAMES,
|
||||
PLATFORM_TOOL_NAMES,
|
||||
SDK_BUILTIN_TOOL_NAMES,
|
||||
CopilotPermissions,
|
||||
_block_matches,
|
||||
all_known_tool_names,
|
||||
apply_tool_permissions,
|
||||
validate_block_identifiers,
|
||||
validate_tool_names,
|
||||
)
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _block_matches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBlockMatches:
|
||||
BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
BLOCK_NAME = "HTTP Request"
|
||||
|
||||
def test_full_uuid_match(self):
|
||||
assert _block_matches(self.BLOCK_ID, self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_full_uuid_case_insensitive(self):
|
||||
assert _block_matches(self.BLOCK_ID.upper(), self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_full_uuid_no_match(self):
|
||||
other = "aaaaaaaa-0000-0000-0000-000000000000"
|
||||
assert not _block_matches(other, self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_match(self):
|
||||
assert _block_matches("c069dc6b", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_case_insensitive(self):
|
||||
assert _block_matches("C069DC6B", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_no_match(self):
|
||||
assert not _block_matches("deadbeef", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_name_match(self):
|
||||
assert _block_matches("HTTP Request", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_name_case_insensitive(self):
|
||||
assert _block_matches("http request", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
assert _block_matches("HTTP REQUEST", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_name_no_match(self):
|
||||
assert not _block_matches("Unknown Block", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_not_matching_as_name(self):
|
||||
# "c069dc6b" is 8 hex chars → treated as partial UUID, NOT name match
|
||||
assert not _block_matches(
|
||||
"c069dc6b", "ffffffff-0000-0000-0000-000000000000", "c069dc6b"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.effective_allowed_tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
ALL_TOOLS = frozenset(
|
||||
["run_block", "web_fetch", "bash_exec", "find_agent", "Task", "Read"]
|
||||
)
|
||||
|
||||
|
||||
class TestEffectiveAllowedTools:
|
||||
def test_empty_list_allows_all(self):
|
||||
perms = CopilotPermissions(tools=[], tools_exclude=True)
|
||||
assert perms.effective_allowed_tools(ALL_TOOLS) == ALL_TOOLS
|
||||
|
||||
def test_empty_whitelist_allows_all(self):
|
||||
# edge: tools_exclude=False but empty list → allow all
|
||||
perms = CopilotPermissions(tools=[], tools_exclude=False)
|
||||
assert perms.effective_allowed_tools(ALL_TOOLS) == ALL_TOOLS
|
||||
|
||||
def test_blacklist_removes_listed(self):
|
||||
perms = CopilotPermissions(tools=["bash_exec", "web_fetch"], tools_exclude=True)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert "bash_exec" not in result
|
||||
assert "web_fetch" not in result
|
||||
assert "run_block" in result
|
||||
assert "Task" in result
|
||||
|
||||
def test_whitelist_keeps_only_listed(self):
|
||||
perms = CopilotPermissions(tools=["run_block", "Task"], tools_exclude=False)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert result == frozenset(["run_block", "Task"])
|
||||
|
||||
def test_whitelist_unknown_tool_yields_empty(self):
|
||||
perms = CopilotPermissions(tools=["nonexistent"], tools_exclude=False)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert result == frozenset()
|
||||
|
||||
def test_blacklist_unknown_tool_ignored(self):
|
||||
perms = CopilotPermissions(tools=["nonexistent"], tools_exclude=True)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert result == ALL_TOOLS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.is_block_allowed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
BLOCK_NAME = "HTTP Request"
|
||||
|
||||
|
||||
class TestIsBlockAllowed:
|
||||
def test_empty_allows_everything(self):
|
||||
perms = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_blacklist_blocks_listed(self):
|
||||
perms = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_blacklist_allows_unlisted(self):
|
||||
perms = CopilotPermissions(blocks=["Other Block"], blocks_exclude=True)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_whitelist_allows_listed(self):
|
||||
perms = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_whitelist_blocks_unlisted(self):
|
||||
perms = CopilotPermissions(blocks=["Other Block"], blocks_exclude=False)
|
||||
assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_blacklist(self):
|
||||
perms = CopilotPermissions(blocks=["c069dc6b"], blocks_exclude=True)
|
||||
assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_full_uuid_whitelist(self):
|
||||
perms = CopilotPermissions(blocks=[BLOCK_ID], blocks_exclude=False)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_parent_blocks_when_child_allows(self):
|
||||
parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
child = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
child._parent = parent
|
||||
assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_parent_allows_when_child_blocks(self):
|
||||
parent = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
child = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
child._parent = parent
|
||||
assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_both_must_allow(self):
|
||||
parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
|
||||
child = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
|
||||
child._parent = parent
|
||||
assert child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_grandparent_blocks_propagate(self):
|
||||
grandparent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
parent = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
parent._parent = grandparent
|
||||
child = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
child._parent = parent
|
||||
assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.merged_with_parent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergedWithParent:
|
||||
def test_tool_intersection(self):
|
||||
all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert "bash_exec" not in effective
|
||||
assert "web_fetch" not in effective
|
||||
assert "run_block" in effective
|
||||
|
||||
def test_parent_whitelist_narrows_child(self):
|
||||
all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
child = CopilotPermissions(tools=[], tools_exclude=True) # allow all
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert effective == frozenset(["run_block"])
|
||||
|
||||
def test_child_cannot_expand_parent_whitelist(self):
|
||||
all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
child = CopilotPermissions(
|
||||
tools=["run_block", "bash_exec"], tools_exclude=False
|
||||
)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
# bash_exec was not in parent's whitelist → must not appear
|
||||
assert "bash_exec" not in effective
|
||||
assert "run_block" in effective
|
||||
|
||||
def test_merged_stored_as_whitelist(self):
|
||||
all_t = frozenset(["run_block", "web_fetch"])
|
||||
parent = CopilotPermissions(tools=[], tools_exclude=True)
|
||||
child = CopilotPermissions(tools=[], tools_exclude=True)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
assert not merged.tools_exclude # stored as whitelist
|
||||
assert set(merged.tools) == {"run_block", "web_fetch"}
|
||||
|
||||
def test_block_parent_stored(self):
|
||||
all_t = frozenset(["run_block"])
|
||||
parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
child = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
# Parent restriction is preserved via _parent
|
||||
assert not merged.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.is_empty
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsEmpty:
|
||||
def test_default_is_empty(self):
|
||||
assert CopilotPermissions().is_empty()
|
||||
|
||||
def test_with_tools_not_empty(self):
|
||||
assert not CopilotPermissions(tools=["bash_exec"]).is_empty()
|
||||
|
||||
def test_with_blocks_not_empty(self):
|
||||
assert not CopilotPermissions(blocks=["HTTP Request"]).is_empty()
|
||||
|
||||
def test_with_parent_not_empty(self):
|
||||
perms = CopilotPermissions()
|
||||
perms._parent = CopilotPermissions(tools=["bash_exec"])
|
||||
assert not perms.is_empty()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_tool_names
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateToolNames:
|
||||
def test_valid_registry_tool(self):
|
||||
assert validate_tool_names(["run_block", "web_fetch"]) == []
|
||||
|
||||
def test_valid_sdk_builtin(self):
|
||||
assert validate_tool_names(["Read", "Task", "WebSearch"]) == []
|
||||
|
||||
def test_invalid_tool(self):
|
||||
result = validate_tool_names(["nonexistent_tool"])
|
||||
assert "nonexistent_tool" in result
|
||||
|
||||
def test_mixed(self):
|
||||
result = validate_tool_names(["run_block", "fake_tool"])
|
||||
assert "fake_tool" in result
|
||||
assert "run_block" not in result
|
||||
|
||||
def test_empty_list(self):
|
||||
assert validate_tool_names([]) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_block_identifiers (async)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestValidateBlockIdentifiers:
|
||||
async def test_empty_list(self):
|
||||
result = await validate_block_identifiers([])
|
||||
assert result == []
|
||||
|
||||
async def test_valid_full_uuid(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(
|
||||
["c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"]
|
||||
)
|
||||
assert result == []
|
||||
|
||||
async def test_invalid_identifier(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(["totally_unknown"])
|
||||
assert "totally_unknown" in result
|
||||
|
||||
async def test_partial_uuid_match(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(["c069dc6b"])
|
||||
assert result == []
|
||||
|
||||
async def test_name_match(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(["http request"])
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# apply_tool_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApplyToolPermissions:
|
||||
def test_empty_permissions_returns_base_unchanged(self, mocker):
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=["mcp__copilot__run_block", "mcp__copilot__web_fetch", "Task"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object(), "web_fetch": object()},
|
||||
)
|
||||
perms = CopilotPermissions()
|
||||
allowed, disallowed = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
assert "mcp__copilot__web_fetch" in allowed
|
||||
|
||||
def test_blacklist_removes_tool(self, mocker):
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__web_fetch",
|
||||
"mcp__copilot__bash_exec",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{
|
||||
"run_block": object(),
|
||||
"web_fetch": object(),
|
||||
"bash_exec": object(),
|
||||
},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "web_fetch", "bash_exec", "Task"]),
|
||||
)
|
||||
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__bash_exec" not in allowed
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
|
||||
def test_whitelist_keeps_only_listed(self, mocker):
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__web_fetch",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object(), "web_fetch": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "web_fetch", "Task", "WebSearch"]),
|
||||
)
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
assert "mcp__copilot__web_fetch" not in allowed
|
||||
assert "Task" not in allowed
|
||||
|
||||
def test_read_tool_always_included_even_when_blacklisted(self, mocker):
|
||||
"""mcp__copilot__Read must stay in allowed even if Read is explicitly blacklisted."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=[],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Task"]),
|
||||
)
|
||||
# Explicitly blacklist Read
|
||||
perms = CopilotPermissions(tools=["Read"], tools_exclude=True)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__Read" in allowed # always preserved for SDK internals
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
assert "Task" in allowed
|
||||
|
||||
def test_read_tool_always_included_with_narrow_whitelist(self, mocker):
|
||||
"""mcp__copilot__Read must stay in allowed even when not in a whitelist."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=[],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Task"]),
|
||||
)
|
||||
# Whitelist only run_block — Read not listed
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__Read" in allowed # always preserved for SDK internals
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
|
||||
def test_e2b_file_tools_included_when_sdk_builtin_whitelisted(self, mocker):
|
||||
"""In E2B mode, whitelisting 'Read' must include mcp__copilot__read_file."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"mcp__copilot__read_file",
|
||||
"mcp__copilot__write_file",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash", "Read", "Write", "Edit", "Glob", "Grep"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Write", "Task"]),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.e2b_file_tools.E2B_FILE_TOOL_NAMES",
|
||||
["read_file", "write_file", "edit_file", "glob", "grep"],
|
||||
)
|
||||
# Whitelist Read and run_block — E2B read_file should be included
|
||||
perms = CopilotPermissions(tools=["Read", "run_block"], tools_exclude=False)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=True)
|
||||
assert "mcp__copilot__read_file" in allowed
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
# Write not whitelisted — write_file should NOT be included
|
||||
assert "mcp__copilot__write_file" not in allowed
|
||||
|
||||
def test_e2b_file_tools_excluded_when_sdk_builtin_blacklisted(self, mocker):
|
||||
"""In E2B mode, blacklisting 'Read' must also remove mcp__copilot__read_file."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"mcp__copilot__read_file",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash", "Read", "Write", "Edit", "Glob", "Grep"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Task"]),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.e2b_file_tools.E2B_FILE_TOOL_NAMES",
|
||||
["read_file", "write_file", "edit_file", "glob", "grep"],
|
||||
)
|
||||
# Blacklist Read — E2B read_file should also be removed
|
||||
perms = CopilotPermissions(tools=["Read"], tools_exclude=True)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=True)
|
||||
assert "mcp__copilot__read_file" not in allowed
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
# mcp__copilot__Read is always preserved for SDK internals
|
||||
assert "mcp__copilot__Read" in allowed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDK_BUILTIN_TOOL_NAMES sanity check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSdkBuiltinToolNames:
|
||||
def test_expected_builtins_present(self):
|
||||
expected = {
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
"TodoWrite",
|
||||
}
|
||||
assert expected.issubset(SDK_BUILTIN_TOOL_NAMES)
|
||||
|
||||
def test_platform_names_match_tool_registry(self):
|
||||
"""PLATFORM_TOOL_NAMES (derived from ToolName Literal) must match TOOL_REGISTRY keys."""
|
||||
registry_keys = frozenset(TOOL_REGISTRY.keys())
|
||||
assert PLATFORM_TOOL_NAMES == registry_keys, (
|
||||
f"ToolName Literal is out of sync with TOOL_REGISTRY. "
|
||||
f"Missing: {registry_keys - PLATFORM_TOOL_NAMES}, "
|
||||
f"Extra: {PLATFORM_TOOL_NAMES - registry_keys}"
|
||||
)
|
||||
|
||||
def test_all_tool_names_is_union(self):
|
||||
"""ALL_TOOL_NAMES must equal PLATFORM_TOOL_NAMES | SDK_BUILTIN_TOOL_NAMES."""
|
||||
assert ALL_TOOL_NAMES == PLATFORM_TOOL_NAMES | SDK_BUILTIN_TOOL_NAMES
|
||||
|
||||
def test_no_overlap_between_platform_and_sdk(self):
|
||||
"""Platform and SDK built-in names must not overlap."""
|
||||
assert PLATFORM_TOOL_NAMES.isdisjoint(SDK_BUILTIN_TOOL_NAMES)
|
||||
|
||||
def test_known_tools_includes_registry_and_builtins(self):
|
||||
known = all_known_tool_names()
|
||||
assert "run_block" in known
|
||||
assert "Read" in known
|
||||
assert "Task" in known
|
||||
@@ -6,39 +6,24 @@ handling the distinction between:
|
||||
- Local mode vs E2B mode (storage/filesystem differences)
|
||||
"""
|
||||
|
||||
from backend.blocks.autopilot import AUTOPILOT_BLOCK_ID
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
_SHARED_TOOL_NOTES = f"""\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
### Sharing files
|
||||
After `write_workspace_file`, embed the `download_url` in Markdown:
|
||||
- File: `[report.csv](workspace://file_id#text/csv)`
|
||||
- Image: ``
|
||||
- Video: ``
|
||||
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Passing file content to tools — @@agptfile: references
|
||||
Instead of copying large file contents into a tool argument, pass a file
|
||||
reference and the platform will load the content for you.
|
||||
|
||||
Syntax: `@@agptfile:<uri>[<start>-<end>]`
|
||||
|
||||
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
|
||||
- `workspace://<file_id>` — workspace file by ID
|
||||
- `workspace:///<path>` — workspace file by virtual path
|
||||
- `/absolute/local/path` — ephemeral or sdk_cwd file
|
||||
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
|
||||
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
|
||||
- URIs that do not start with `workspace://` or `/` are **not** expanded.
|
||||
### File references — @@agptfile:
|
||||
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
|
||||
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
|
||||
- `/absolute/path` — local/sandbox files
|
||||
- `[start-end]` — optional 1-indexed line range
|
||||
- Multiple refs per argument supported. Only `workspace://` and absolute paths are expanded.
|
||||
|
||||
Examples:
|
||||
```
|
||||
@@ -49,21 +34,9 @@ Examples:
|
||||
@@agptfile:/home/user/script.py
|
||||
```
|
||||
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
**Structured data**: When the entire argument is a single file reference, the platform auto-parses by extension/MIME. Supported: JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel (.xlsx only; legacy `.xls` is NOT supported). Unrecognised formats return plain string.
|
||||
|
||||
**Structured data**: When the **entire** argument value is a single file
|
||||
reference (no surrounding text), the platform automatically parses the file
|
||||
content based on its extension or MIME type. Supported formats: JSON, JSONL,
|
||||
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
|
||||
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
|
||||
the rows will be parsed into `list[list[str]]` automatically. If the format is
|
||||
unrecognised or parsing fails, the content is returned as a plain string.
|
||||
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
|
||||
|
||||
**Type coercion**: The platform also coerces expanded values to match the
|
||||
block's expected input types. For example, if a block expects `list[list[str]]`
|
||||
and the expanded value is a JSON string, it will be parsed into the correct type.
|
||||
**Type coercion**: The platform auto-coerces expanded string values to match block input types (e.g. JSON string → `list[list[str]]`).
|
||||
|
||||
### Media file inputs (format: "file")
|
||||
Some block inputs accept media files — their schema shows `"format": "file"`.
|
||||
@@ -81,18 +54,97 @@ that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{
|
||||
"files": [{
|
||||
{{
|
||||
"files": [{{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}]
|
||||
}
|
||||
}}]
|
||||
}}
|
||||
```
|
||||
|
||||
### Writing large files — CRITICAL
|
||||
**Never write an entire large document in a single tool call.** When the
|
||||
content you want to write exceeds ~2000 words the tool call's output token
|
||||
limit will silently truncate the arguments, producing an empty `{{}}` input
|
||||
that fails repeatedly.
|
||||
|
||||
**Preferred: compose from file references.** If the data is already in
|
||||
files (tool outputs, workspace files), compose the report in one call
|
||||
using `@@agptfile:` references — the system expands them inline:
|
||||
|
||||
```bash
|
||||
cat > report.md << 'EOF'
|
||||
# Research Report
|
||||
## Data from web research
|
||||
@@agptfile:/home/user/web_results.txt
|
||||
## Block execution output
|
||||
@@agptfile:workspace://<file_id>
|
||||
## Conclusion
|
||||
<brief synthesis>
|
||||
EOF
|
||||
```
|
||||
|
||||
**Fallback: write section-by-section.** When you must generate content
|
||||
from conversation context (no files to reference), split into multiple
|
||||
`bash_exec` calls — one section per call:
|
||||
|
||||
```bash
|
||||
cat > report.md << 'EOF'
|
||||
# Section 1
|
||||
<content from your earlier tool call results>
|
||||
EOF
|
||||
```
|
||||
```bash
|
||||
cat >> report.md << 'EOF'
|
||||
# Section 2
|
||||
<content from your earlier tool call results>
|
||||
EOF
|
||||
```
|
||||
Use `cat >` for the first chunk and `cat >>` to append subsequent chunks.
|
||||
Do not re-fetch or re-generate data you already have from prior tool calls.
|
||||
|
||||
After building the file, reference it with `@@agptfile:` in other tools:
|
||||
`@@agptfile:/home/user/report.md`
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
|
||||
### Delegating to another autopilot (sub-autopilot pattern)
|
||||
Use the **AutoPilotBlock** (`run_block` with block_id
|
||||
`{AUTOPILOT_BLOCK_ID}`) to delegate a task to a fresh
|
||||
autopilot instance. The sub-autopilot has its own full tool set and can
|
||||
perform multi-step work autonomously.
|
||||
|
||||
- **Input**: `prompt` (required) — the task description.
|
||||
Optional: `system_context` to constrain behavior, `session_id` to
|
||||
continue a previous conversation, `max_recursion_depth` (default 3).
|
||||
- **Output**: `response` (text), `tool_calls` (list), `session_id`
|
||||
(for continuation), `conversation_history`, `token_usage`.
|
||||
|
||||
Use this when a task is complex enough to benefit from a separate
|
||||
autopilot context, e.g. "research X and write a report" while the
|
||||
parent autopilot handles orchestration.
|
||||
"""
|
||||
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
|
||||
_E2B_TOOL_NOTES = """
|
||||
### GitHub CLI (`gh`) and git
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
- If the token changes mid-session (e.g. user reconnects with a new token),
|
||||
run `gh auth setup-git` to re-register the credential helper.
|
||||
- If `gh` or `git` fails with an authentication error (e.g. "authentication
|
||||
required", "could not read Username", or exit code 128), call
|
||||
`connect_integration(provider="github")` to surface the GitHub credentials
|
||||
setup card so the user can connect their account. Once connected, retry
|
||||
the operation.
|
||||
- For operations that need broader access (e.g. private org repos, GitHub
|
||||
Actions), pass the required scopes: e.g.
|
||||
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -105,6 +157,7 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
extra_notes: str = "",
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
@@ -119,6 +172,7 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
extra_notes: Environment-specific notes appended after shared notes
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
@@ -128,17 +182,12 @@ def _build_storage_supplement(
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs {sandbox_type}.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `{working_dir}`
|
||||
- All SDK file tools AND `bash_exec` operate on the same filesystem
|
||||
- Use relative paths or absolute paths under `{working_dir}` for all file operations
|
||||
### Shell & filesystem
|
||||
- The SDK built-in Bash tool is NOT available. Use `bash_exec` for shell commands ({sandbox_type}). Working dir: `{working_dir}`
|
||||
- SDK file tools (Read/Write/Edit/Glob/Grep) and `bash_exec` share one filesystem — use relative or absolute paths under this dir.
|
||||
- `read_workspace_file`/`write_workspace_file` operate on **persistent cloud workspace storage** (separate from the working dir).
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **{storage_system_1_name}** (`{working_dir}`):
|
||||
{characteristics}
|
||||
{persistence}
|
||||
@@ -152,12 +201,23 @@ def _build_storage_supplement(
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
### SDK tool-result files
|
||||
When tool outputs are large, the SDK truncates them and saves the full output to
|
||||
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
|
||||
always use `read_file` or `Read` (NOT `read_workspace_file`).
|
||||
`read_workspace_file` reads from cloud workspace storage, where SDK
|
||||
tool-results are NOT stored.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
"""Local ephemeral storage (files lost between turns).
|
||||
|
||||
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
|
||||
like gh will not work — no integration env-var notes are included.
|
||||
"""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
@@ -175,7 +235,11 @@ def _get_local_storage_supplement(cwd: str) -> str:
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
"""Cloud persistent sandbox (files survive across turns in session).
|
||||
|
||||
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
|
||||
injected per command in bash_exec — include the CLI guidance notes.
|
||||
"""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
@@ -190,6 +254,7 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
extra_notes=_E2B_TOOL_NOTES,
|
||||
)
|
||||
|
||||
|
||||
|
||||
63
autogpt_platform/backend/backend/copilot/providers.py
Normal file
63
autogpt_platform/backend/backend/copilot/providers.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Single source of truth for copilot-supported integration providers.
|
||||
|
||||
Both :mod:`~backend.copilot.integration_creds` (env-var injection) and
|
||||
:mod:`~backend.copilot.tools.connect_integration` (UI setup card) import from
|
||||
here, eliminating the risk of the two registries drifting out of sync.
|
||||
"""
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class ProviderEntry(TypedDict):
|
||||
"""Metadata for a supported integration provider.
|
||||
|
||||
Attributes:
|
||||
name: Human-readable display name (e.g. "GitHub").
|
||||
env_vars: Environment variable names injected when the provider is
|
||||
connected (e.g. ``["GH_TOKEN", "GITHUB_TOKEN"]``).
|
||||
default_scopes: Default OAuth scopes requested when the agent does not
|
||||
specify any.
|
||||
"""
|
||||
|
||||
name: str
|
||||
env_vars: list[str]
|
||||
default_scopes: list[str]
|
||||
|
||||
|
||||
def _is_github_oauth_configured() -> bool:
|
||||
"""Return True if GitHub OAuth env vars are set.
|
||||
|
||||
Uses a lazy import to avoid triggering ``Secrets()`` during module import,
|
||||
which can fail in environments where secrets are not yet loaded (e.g. tests,
|
||||
CLI tooling).
|
||||
"""
|
||||
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
return GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
|
||||
# -- Registry ----------------------------------------------------------------
|
||||
# Add new providers here. Both env-var injection and the setup-card tool read
|
||||
# from this single registry.
|
||||
|
||||
SUPPORTED_PROVIDERS: dict[str, ProviderEntry] = {
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"env_vars": ["GH_TOKEN", "GITHUB_TOKEN"],
|
||||
"default_scopes": ["repo"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_provider_auth_types(provider: str) -> list[str]:
|
||||
"""Return the supported credential types for *provider* at runtime.
|
||||
|
||||
OAuth types are only offered when the corresponding OAuth client env vars
|
||||
are configured.
|
||||
"""
|
||||
if provider == "github":
|
||||
if _is_github_oauth_configured():
|
||||
return ["api_key", "oauth2"]
|
||||
return ["api_key"]
|
||||
# Default for unknown/future providers — API key only.
|
||||
return ["api_key"]
|
||||
@@ -43,6 +43,7 @@ class ResponseType(str, Enum):
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
HEARTBEAT = "heartbeat"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
|
||||
@@ -263,3 +264,19 @@ class StreamHeartbeat(StreamBaseResponse):
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE comment format to keep connection alive."""
|
||||
return ": heartbeat\n\n"
|
||||
|
||||
|
||||
class StreamStatus(StreamBaseResponse):
|
||||
"""Transient status notification shown to the user during long operations.
|
||||
|
||||
Used to provide feedback when the backend performs behind-the-scenes work
|
||||
(e.g., compacting conversation context on a retry) that would otherwise
|
||||
leave the user staring at an unexplained pause.
|
||||
|
||||
Sent as a proper ``data:`` event so the frontend can display it to the
|
||||
user. The AI SDK stream parser gracefully skips unknown chunk types
|
||||
(logs a console warning), so this does not break the stream.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.STATUS
|
||||
message: str = Field(..., description="Human-readable status message")
|
||||
|
||||
@@ -19,9 +19,19 @@ least invasive way to break the cycle while keeping module-level constants
|
||||
intact.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
# Static imports for type checkers so they can resolve __all__ entries
|
||||
# without executing the lazy-import machinery at runtime.
|
||||
if TYPE_CHECKING:
|
||||
from .collect import CopilotResult as CopilotResult
|
||||
from .collect import collect_copilot_response as collect_copilot_response
|
||||
from .service import stream_chat_completion_sdk as stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server as create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"CopilotResult",
|
||||
"collect_copilot_response",
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
@@ -29,6 +39,8 @@ __all__ = [
|
||||
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||
# pair so new exports can be added without touching __getattr__ itself.
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"CopilotResult": (".collect", "CopilotResult"),
|
||||
"collect_copilot_response": (".collect", "collect_copilot_response"),
|
||||
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||
}
|
||||
|
||||
@@ -143,6 +143,85 @@ To use an MCP (Model Context Protocol) tool as a node in the agent:
|
||||
tool_arguments.
|
||||
6. Output: `result` (the tool's return value) and `error` (error message)
|
||||
|
||||
### Using OrchestratorBlock (AI Orchestrator with Agent Mode)
|
||||
|
||||
To create an agent where AI autonomously decides which tools or sub-agents to
|
||||
call in a loop until the task is complete:
|
||||
1. Create a `OrchestratorBlock` node
|
||||
(ID: `3b191d9f-356f-482d-8238-ba04b6d18381`)
|
||||
2. Set `input_default`:
|
||||
- `agent_mode_max_iterations`: Choose based on task complexity:
|
||||
- `1` for single-step tool calls (AI picks one tool, calls it, done)
|
||||
- `3`–`10` for multi-step tasks (AI calls tools iteratively)
|
||||
- `-1` for open-ended orchestration (AI loops until it decides it's done).
|
||||
**Use with caution** — prefer bounded iterations (3–10) unless
|
||||
genuinely needed, as unbounded loops risk runaway cost and execution.
|
||||
Do NOT use `0` (traditional mode) — it requires complex external
|
||||
conversation-history loop wiring that the agent generator does not
|
||||
produce.
|
||||
- `conversation_compaction`: `true` (recommended to avoid context overflow)
|
||||
- `retry`: Number of retries on tool-call failure (default `3`).
|
||||
Set to `0` to disable retries.
|
||||
- `multiple_tool_calls`: Whether the AI can invoke multiple tools in a
|
||||
single turn (default `false`). Enable when tools are independent and
|
||||
can run concurrently.
|
||||
- Optional: `sys_prompt` for extra LLM context about how to orchestrate
|
||||
3. Wire the `prompt` input from an `AgentInputBlock` (the user's task)
|
||||
4. Create downstream tool blocks — regular blocks **or** `AgentExecutorBlock`
|
||||
nodes that call sub-agents
|
||||
5. Link each tool to the Orchestrator: set `source_name: "tools"` on
|
||||
the Orchestrator side and `sink_name: <input_field>` on each tool
|
||||
block's input. Create one link per input field the tool needs.
|
||||
6. Wire the `finished` output to an `AgentOutputBlock` for the final result
|
||||
7. Credentials (LLM API key) are configured by the user in the platform UI
|
||||
after saving — do NOT require them upfront
|
||||
|
||||
**Example — Orchestrator calling two sub-agents:**
|
||||
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
|
||||
- Node 2: `OrchestratorBlock` (input_default:
|
||||
`{"agent_mode_max_iterations": 10, "conversation_compaction": true}`)
|
||||
- Node 3: `AgentExecutorBlock` (sub-agent A — set `graph_id`, `graph_version`,
|
||||
`input_schema`, `output_schema` from library agent)
|
||||
- Node 4: `AgentExecutorBlock` (sub-agent B — same pattern)
|
||||
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
|
||||
- Links:
|
||||
- Input→Orchestrator: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- Orchestrator→Agent A (per input field): `source_name: "tools"`,
|
||||
`sink_name: "<agent_a_input_field>"`
|
||||
- Orchestrator→Agent B (per input field): `source_name: "tools"`,
|
||||
`sink_name: "<agent_b_input_field>"`
|
||||
- Orchestrator→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
|
||||
**Example — Orchestrator calling regular blocks as tools:**
|
||||
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
|
||||
- Node 2: `OrchestratorBlock` (input_default:
|
||||
`{"agent_mode_max_iterations": 5, "conversation_compaction": true}`)
|
||||
- Node 3: `GetWebpageBlock` (regular block — the AI calls it as a tool)
|
||||
- Node 4: `AITextGeneratorBlock` (another regular block as a tool)
|
||||
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
|
||||
- Links:
|
||||
- Input→Orchestrator: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- Orchestrator→GetWebpage: `source_name: "tools"`, `sink_name: "url"`
|
||||
- Orchestrator→AITextGenerator: `source_name: "tools"`, `sink_name: "prompt"`
|
||||
- Orchestrator→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
|
||||
Regular blocks work exactly like sub-agents as tools — wire each input
|
||||
field from `source_name: "tools"` on the Orchestrator side.
|
||||
|
||||
### Testing with Dry Run
|
||||
|
||||
After saving an agent, suggest a dry run to validate wiring without consuming
|
||||
real API calls, credentials, or credits:
|
||||
|
||||
1. **Run**: Call `run_agent` or `run_block` with `dry_run=True` and provide
|
||||
sample inputs. This executes the graph with mock outputs, verifying that
|
||||
links resolve correctly and required inputs are satisfied.
|
||||
2. **Check results**: Call `view_agent_output` with `show_execution_details=True`
|
||||
to inspect the full node-by-node execution trace. This shows what each node
|
||||
received as input and produced as output, making it easy to spot wiring issues.
|
||||
3. **Iterate**: If the dry run reveals wiring issues or missing inputs, fix
|
||||
the agent JSON and re-save before suggesting a real execution.
|
||||
|
||||
### Example: Simple AI Text Processor
|
||||
|
||||
A minimal agent with input, processing, and output:
|
||||
|
||||
232
autogpt_platform/backend/backend/copilot/sdk/collect.py
Normal file
232
autogpt_platform/backend/backend/copilot/sdk/collect.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""Public helpers for consuming a copilot stream as a simple request-response.
|
||||
|
||||
This module exposes :class:`CopilotResult` and :func:`collect_copilot_response`
|
||||
so that callers (e.g. the AutoPilot block) can consume the copilot stream
|
||||
without implementing their own event loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from .. import stream_registry
|
||||
from ..response_model import (
|
||||
StreamError,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from .service import stream_chat_completion_sdk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Identifiers used when registering AutoPilot-originated streams in the
|
||||
# stream registry. Distinct from "chat_stream"/"chat" used by the HTTP SSE
|
||||
# endpoint, making it easy to filter AutoPilot streams in logs/observability.
|
||||
AUTOPILOT_TOOL_CALL_ID = "autopilot_stream"
|
||||
AUTOPILOT_TOOL_NAME = "autopilot"
|
||||
|
||||
|
||||
class CopilotResult:
|
||||
"""Aggregated result from consuming a copilot stream.
|
||||
|
||||
Returned by :func:`collect_copilot_response` so callers don't need to
|
||||
implement their own event-loop over the raw stream events.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"response_text",
|
||||
"tool_calls",
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.response_text: str = ""
|
||||
self.tool_calls: list[dict[str, Any]] = []
|
||||
self.prompt_tokens: int = 0
|
||||
self.completion_tokens: int = 0
|
||||
self.total_tokens: int = 0
|
||||
|
||||
|
||||
class _RegistryHandle(BaseModel):
|
||||
"""Tracks stream registry session state for cleanup."""
|
||||
|
||||
publish_turn_id: str = ""
|
||||
error_msg: str | None = None
|
||||
error_already_published: bool = False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _registry_session(
|
||||
session_id: str, user_id: str, turn_id: str
|
||||
) -> AsyncIterator[_RegistryHandle]:
|
||||
"""Create a stream registry session and ensure it is finalized."""
|
||||
handle = _RegistryHandle(publish_turn_id=turn_id)
|
||||
try:
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id=AUTOPILOT_TOOL_CALL_ID,
|
||||
tool_name=AUTOPILOT_TOOL_NAME,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning(
|
||||
"[collect] Failed to create stream registry session for %s, "
|
||||
"frontend will not receive real-time updates",
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
# Disable chunk publishing but keep finalization enabled so
|
||||
# mark_session_completed can clean up any partial registry state.
|
||||
handle.publish_turn_id = ""
|
||||
|
||||
try:
|
||||
yield handle
|
||||
finally:
|
||||
try:
|
||||
await stream_registry.mark_session_completed(
|
||||
session_id,
|
||||
error_message=handle.error_msg,
|
||||
skip_error_publish=handle.error_already_published,
|
||||
)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning(
|
||||
"[collect] Failed to mark stream completed for %s",
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
class _ToolCallEntry(BaseModel):
|
||||
"""A single tool call observed during stream consumption."""
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
input: Any
|
||||
output: Any = None
|
||||
success: bool | None = None
|
||||
|
||||
|
||||
class _EventAccumulator(BaseModel):
|
||||
"""Mutable accumulator for stream events."""
|
||||
|
||||
response_parts: list[str] = Field(default_factory=list)
|
||||
tool_calls: list[_ToolCallEntry] = Field(default_factory=list)
|
||||
tool_calls_by_id: dict[str, _ToolCallEntry] = Field(default_factory=dict)
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
def _process_event(event: object, acc: _EventAccumulator) -> str | None:
|
||||
"""Process a single stream event and return error_msg if StreamError.
|
||||
|
||||
Uses structural pattern matching for dispatch per project guidelines.
|
||||
"""
|
||||
match event:
|
||||
case StreamTextDelta(delta=delta):
|
||||
acc.response_parts.append(delta)
|
||||
case StreamToolInputAvailable() as e:
|
||||
entry = _ToolCallEntry(
|
||||
tool_call_id=e.toolCallId,
|
||||
tool_name=e.toolName,
|
||||
input=e.input,
|
||||
)
|
||||
acc.tool_calls.append(entry)
|
||||
acc.tool_calls_by_id[e.toolCallId] = entry
|
||||
case StreamToolOutputAvailable() as e:
|
||||
if tc := acc.tool_calls_by_id.get(e.toolCallId):
|
||||
tc.output = e.output
|
||||
tc.success = e.success
|
||||
else:
|
||||
logger.debug(
|
||||
"Received tool output for unknown tool_call_id: %s",
|
||||
e.toolCallId,
|
||||
)
|
||||
case StreamUsage() as e:
|
||||
acc.prompt_tokens += e.prompt_tokens
|
||||
acc.completion_tokens += e.completion_tokens
|
||||
acc.total_tokens += e.total_tokens
|
||||
case StreamError(errorText=err):
|
||||
return err
|
||||
return None
|
||||
|
||||
|
||||
async def collect_copilot_response(
|
||||
*,
|
||||
session_id: str,
|
||||
message: str,
|
||||
user_id: str,
|
||||
is_user_message: bool = True,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> CopilotResult:
|
||||
"""Consume :func:`stream_chat_completion_sdk` and return aggregated results.
|
||||
|
||||
Registers with the stream registry so the frontend can connect via SSE
|
||||
and receive real-time updates while the AutoPilot block is executing.
|
||||
|
||||
Args:
|
||||
session_id: Chat session to use.
|
||||
message: The user message / prompt.
|
||||
user_id: Authenticated user ID.
|
||||
is_user_message: Whether this is a user-initiated message.
|
||||
permissions: Optional capability filter. When provided, restricts
|
||||
which tools and blocks the copilot may use during this execution.
|
||||
|
||||
Returns:
|
||||
A :class:`CopilotResult` with the aggregated response text,
|
||||
tool calls, and token usage.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the stream yields a ``StreamError`` event.
|
||||
"""
|
||||
turn_id = str(uuid.uuid4())
|
||||
async with _registry_session(session_id, user_id, turn_id) as handle:
|
||||
try:
|
||||
raw_stream = stream_chat_completion_sdk(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
published_stream = stream_registry.stream_and_publish(
|
||||
session_id=session_id,
|
||||
turn_id=handle.publish_turn_id,
|
||||
stream=raw_stream,
|
||||
)
|
||||
|
||||
acc = _EventAccumulator()
|
||||
async for event in published_stream:
|
||||
if err := _process_event(event, acc):
|
||||
handle.error_msg = err
|
||||
# stream_and_publish skips StreamError events, so
|
||||
# mark_session_completed must publish the error to Redis.
|
||||
handle.error_already_published = False
|
||||
raise RuntimeError(f"Copilot error: {err}")
|
||||
except Exception:
|
||||
if handle.error_msg is None:
|
||||
handle.error_msg = "AutoPilot execution failed"
|
||||
raise
|
||||
|
||||
result = CopilotResult()
|
||||
result.response_text = "".join(acc.response_parts)
|
||||
result.tool_calls = [tc.model_dump() for tc in acc.tool_calls]
|
||||
result.prompt_tokens = acc.prompt_tokens
|
||||
result.completion_tokens = acc.completion_tokens
|
||||
result.total_tokens = acc.total_tokens
|
||||
return result
|
||||
177
autogpt_platform/backend/backend/copilot/sdk/collect_test.py
Normal file
177
autogpt_platform/backend/backend/copilot/sdk/collect_test.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""Tests for collect_copilot_response stream registry integration."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.copilot.sdk.collect import collect_copilot_response
|
||||
|
||||
|
||||
def _mock_stream_fn(*events):
|
||||
"""Return a callable that returns an async generator."""
|
||||
|
||||
async def _gen(**_kwargs):
|
||||
for e in events:
|
||||
yield e
|
||||
|
||||
return _gen
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry():
|
||||
"""Patch stream_registry module used by collect."""
|
||||
with patch("backend.copilot.sdk.collect.stream_registry") as m:
|
||||
m.create_session = AsyncMock()
|
||||
m.publish_chunk = AsyncMock()
|
||||
m.mark_session_completed = AsyncMock()
|
||||
|
||||
# stream_and_publish: pass-through that also publishes (real logic)
|
||||
# We re-implement the pass-through here so the event loop works,
|
||||
# but still track publish_chunk calls via the mock.
|
||||
async def _stream_and_publish(session_id, turn_id, stream):
|
||||
async for event in stream:
|
||||
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
|
||||
await m.publish_chunk(turn_id, event)
|
||||
yield event
|
||||
|
||||
m.stream_and_publish = _stream_and_publish
|
||||
yield m
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stream_fn_patch():
|
||||
"""Helper to patch stream_chat_completion_sdk."""
|
||||
|
||||
def _patch(events):
|
||||
return patch(
|
||||
"backend.copilot.sdk.collect.stream_chat_completion_sdk",
|
||||
new=_mock_stream_fn(*events),
|
||||
)
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_registry_called_on_success(mock_registry, stream_fn_patch):
|
||||
"""Stream registry create/publish/complete are called correctly on success."""
|
||||
events = [
|
||||
StreamTextDelta(id="t1", delta="Hello "),
|
||||
StreamTextDelta(id="t1", delta="world"),
|
||||
StreamUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
StreamFinish(),
|
||||
]
|
||||
|
||||
with stream_fn_patch(events):
|
||||
result = await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result.response_text == "Hello world"
|
||||
assert result.total_tokens == 15
|
||||
|
||||
mock_registry.create_session.assert_awaited_once()
|
||||
# StreamFinish should NOT be published (mark_session_completed does it)
|
||||
published_types = [
|
||||
type(call.args[1]).__name__
|
||||
for call in mock_registry.publish_chunk.call_args_list
|
||||
]
|
||||
assert "StreamFinish" not in published_types
|
||||
assert "StreamTextDelta" in published_types
|
||||
|
||||
mock_registry.mark_session_completed.assert_awaited_once()
|
||||
_, kwargs = mock_registry.mark_session_completed.call_args
|
||||
assert kwargs.get("error_message") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_registry_error_on_stream_error(mock_registry, stream_fn_patch):
|
||||
"""mark_session_completed receives error message when StreamError occurs."""
|
||||
events = [
|
||||
StreamTextDelta(id="t1", delta="partial"),
|
||||
StreamError(errorText="something broke"),
|
||||
]
|
||||
|
||||
with stream_fn_patch(events):
|
||||
with pytest.raises(RuntimeError, match="something broke"):
|
||||
await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
_, kwargs = mock_registry.mark_session_completed.call_args
|
||||
assert kwargs.get("error_message") == "something broke"
|
||||
# stream_and_publish skips StreamError, so mark_session_completed must
|
||||
# publish it (skip_error_publish=False).
|
||||
assert kwargs.get("skip_error_publish") is False
|
||||
|
||||
# StreamError should NOT be published via publish_chunk — mark_session_completed
|
||||
# handles it to avoid double-publication.
|
||||
published_types = [
|
||||
type(call.args[1]).__name__
|
||||
for call in mock_registry.publish_chunk.call_args_list
|
||||
]
|
||||
assert "StreamError" not in published_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graceful_degradation_when_create_session_fails(
|
||||
mock_registry, stream_fn_patch
|
||||
):
|
||||
"""AutoPilot still works when stream registry create_session raises."""
|
||||
events = [
|
||||
StreamTextDelta(id="t1", delta="works"),
|
||||
StreamFinish(),
|
||||
]
|
||||
mock_registry.create_session = AsyncMock(side_effect=ConnectionError("Redis down"))
|
||||
|
||||
with stream_fn_patch(events):
|
||||
result = await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result.response_text == "works"
|
||||
# publish_chunk should NOT be called because turn_id was cleared
|
||||
mock_registry.publish_chunk.assert_not_awaited()
|
||||
# mark_session_completed IS still called to clean up any partial state
|
||||
mock_registry.mark_session_completed.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_calls_published_and_collected(mock_registry, stream_fn_patch):
|
||||
"""Tool call events are both published to registry and collected in result."""
|
||||
events = [
|
||||
StreamToolInputAvailable(
|
||||
toolCallId="tc-1", toolName="read_file", input={"path": "/tmp"}
|
||||
),
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId="tc-1", output="file contents", success=True
|
||||
),
|
||||
StreamTextDelta(id="t1", delta="done"),
|
||||
StreamFinish(),
|
||||
]
|
||||
|
||||
with stream_fn_patch(events):
|
||||
result = await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0]["tool_name"] == "read_file"
|
||||
assert result.tool_calls[0]["output"] == "file contents"
|
||||
assert result.tool_calls[0]["success"] is True
|
||||
assert result.response_text == "done"
|
||||
@@ -12,6 +12,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
@@ -119,14 +120,12 @@ def filter_compaction_messages(
|
||||
filtered: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if msg.role == "assistant" and msg.tool_calls:
|
||||
real_calls: list[dict[str, Any]] = []
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
|
||||
compaction_ids.add(tc.get("id", ""))
|
||||
real_calls = [
|
||||
tc
|
||||
for tc in msg.tool_calls
|
||||
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
|
||||
]
|
||||
else:
|
||||
real_calls.append(tc)
|
||||
if not real_calls and not msg.content:
|
||||
continue
|
||||
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
|
||||
@@ -222,6 +221,7 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._compact_start.clear()
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
54
autogpt_platform/backend/backend/copilot/sdk/conftest.py
Normal file
54
autogpt_platform/backend/backend/copilot/sdk/conftest.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Shared test fixtures for copilot SDK tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_chat_config():
|
||||
"""Mock ChatConfig so compact_transcript tests skip real config lookup."""
|
||||
with patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type("Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"})(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
|
||||
"""Build a minimal valid JSONL transcript from (role, content) pairs.
|
||||
|
||||
Use this helper in any copilot SDK test that needs a well-formed
|
||||
transcript without hitting the real storage layer.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for role, content in pairs:
|
||||
uid = str(uuid4())
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
msg: dict = {"role": role, "content": content}
|
||||
if role == "assistant":
|
||||
msg.update(
|
||||
{
|
||||
"model": "",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": msg,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n"
|
||||
@@ -1,9 +1,17 @@
|
||||
"""Dummy SDK service for testing copilot streaming.
|
||||
|
||||
Returns mock streaming responses without calling Claude Agent SDK.
|
||||
Enable via COPILOT_TEST_MODE=true environment variable.
|
||||
Enable via CHAT_TEST_MODE=true in .env (ChatConfig.test_mode).
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
|
||||
Magic keywords (case-insensitive, anywhere in message):
|
||||
__test_transient_error__ — Simulate a transient Anthropic API error
|
||||
(ECONNRESET). Streams partial text, then
|
||||
yields StreamError with retryable prefix.
|
||||
__test_fatal_error__ — Simulate a non-retryable SDK error.
|
||||
__test_slow_response__ — Simulate a slow response (2s per word).
|
||||
(no keyword) — Normal dummy response.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -12,12 +20,39 @@ import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
from ..constants import (
|
||||
COPILOT_ERROR_PREFIX,
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX,
|
||||
FRIENDLY_TRANSIENT_MSG,
|
||||
)
|
||||
from ..model import ChatMessage, ChatSession, get_chat_session, upsert_chat_session
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _safe_upsert(session: ChatSession) -> None:
|
||||
"""Best-effort session persist — skip silently if DB is unavailable."""
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception:
|
||||
logger.debug("[TEST MODE] Could not persist session (DB unavailable)")
|
||||
|
||||
|
||||
def _has_keyword(message: str | None, keyword: str) -> bool:
|
||||
return keyword in (message or "").lower()
|
||||
|
||||
|
||||
async def stream_chat_completion_dummy(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -36,24 +71,89 @@ async def stream_chat_completion_dummy(
|
||||
- No timeout occurs
|
||||
- Text arrives in chunks
|
||||
- StreamFinish is sent by mark_session_completed
|
||||
|
||||
See module docstring for magic keywords that trigger error scenarios.
|
||||
"""
|
||||
logger.warning(
|
||||
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
|
||||
)
|
||||
|
||||
# Load session from DB (matches SDK service behaviour) so error markers
|
||||
# and the assistant reply are persisted and survive page refresh.
|
||||
# Best-effort: skip if DB is unavailable (e.g. unit tests).
|
||||
if session is None:
|
||||
try:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
except Exception:
|
||||
logger.debug("[TEST MODE] Could not load session (DB unavailable)")
|
||||
session = None
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Start the stream
|
||||
# Start the stream (matches baseline: StreamStart → StreamStartStep)
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
yield StreamStartStep()
|
||||
|
||||
# Simulate streaming text response with delays
|
||||
# --- Magic keyword: transient error (retryable) -------------------------
|
||||
if _has_keyword(message, "__test_transient_error__"):
|
||||
# Stream some partial text first (simulates mid-stream failure)
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
for word in ["Working", "on", "it..."]:
|
||||
yield StreamTextDelta(id=text_block_id, delta=f"{word} ")
|
||||
await asyncio.sleep(0.1)
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
yield StreamFinishStep()
|
||||
# Persist retryable marker so "Try Again" button shows after refresh
|
||||
if session:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_RETRYABLE_ERROR_PREFIX} {FRIENDLY_TRANSIENT_MSG}",
|
||||
)
|
||||
)
|
||||
await _safe_upsert(session)
|
||||
yield StreamError(
|
||||
errorText=FRIENDLY_TRANSIENT_MSG,
|
||||
code="transient_api_error",
|
||||
)
|
||||
return
|
||||
|
||||
# --- Magic keyword: fatal error (non-retryable) -------------------------
|
||||
if _has_keyword(message, "__test_fatal_error__"):
|
||||
yield StreamFinishStep()
|
||||
error_msg = "Internal SDK error: model refused to respond"
|
||||
# Persist non-retryable error marker
|
||||
if session:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_ERROR_PREFIX} {error_msg}",
|
||||
)
|
||||
)
|
||||
await _safe_upsert(session)
|
||||
yield StreamError(errorText=error_msg, code="sdk_error")
|
||||
return
|
||||
|
||||
# --- Magic keyword: slow response ---------------------------------------
|
||||
delay = 2.0 if _has_keyword(message, "__test_slow_response__") else 0.1
|
||||
|
||||
# --- Normal dummy response ----------------------------------------------
|
||||
dummy_response = "I counted: 1... 2... 3. All done!"
|
||||
words = dummy_response.split()
|
||||
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
for i, word in enumerate(words):
|
||||
# Add space except for last word
|
||||
text = word if i == len(words) - 1 else f"{word} "
|
||||
yield StreamTextDelta(id=text_block_id, delta=text)
|
||||
# Small delay to simulate real streaming
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(delay)
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
|
||||
# Persist the assistant reply so it survives page refresh
|
||||
if session:
|
||||
session.messages.append(ChatMessage(role="assistant", content=dummy_response))
|
||||
await _safe_upsert(session)
|
||||
|
||||
yield StreamFinishStep()
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
|
||||
Glob/Grep so that all file operations share the same ``/home/user``
|
||||
filesystem as ``bash_exec``.
|
||||
and ``/tmp`` filesystems as ``bash_exec``.
|
||||
|
||||
SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
|
||||
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
||||
@@ -16,16 +16,51 @@ import shlex
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.copilot.context import (
|
||||
E2B_ALLOWED_DIRS,
|
||||
E2B_ALLOWED_DIRS_STR,
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
is_within_allowed_dirs,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _check_sandbox_symlink_escape(
|
||||
sandbox: Any,
|
||||
parent: str,
|
||||
) -> str | None:
|
||||
"""Resolve the canonical parent path inside the sandbox to detect symlink escapes.
|
||||
|
||||
``normpath`` (used by ``resolve_sandbox_path``) only normalises the string;
|
||||
``readlink -f`` follows actual symlinks on the sandbox filesystem.
|
||||
|
||||
Returns the canonical parent path, or ``None`` if the path escapes
|
||||
the allowed sandbox directories.
|
||||
|
||||
Note: There is an inherent TOCTOU window between this check and the
|
||||
subsequent ``sandbox.files.write()``. A symlink could theoretically be
|
||||
replaced between the two operations. This is acceptable in the E2B
|
||||
sandbox model since the sandbox is single-user and ephemeral.
|
||||
"""
|
||||
canonical_res = await sandbox.commands.run(
|
||||
f"readlink -f {shlex.quote(parent or E2B_WORKDIR)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=5,
|
||||
)
|
||||
canonical_parent = (canonical_res.stdout or "").strip()
|
||||
if (
|
||||
canonical_res.exit_code != 0
|
||||
or not canonical_parent
|
||||
or not is_within_allowed_dirs(canonical_parent)
|
||||
):
|
||||
return None
|
||||
return canonical_parent
|
||||
|
||||
|
||||
def _get_sandbox():
|
||||
return get_current_sandbox()
|
||||
|
||||
@@ -54,6 +89,38 @@ def _get_sandbox_and_path(
|
||||
return sandbox, remote
|
||||
|
||||
|
||||
async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
|
||||
"""Write *content* to *path* inside the sandbox.
|
||||
|
||||
The E2B filesystem API (``sandbox.files.write``) and the command API
|
||||
(``sandbox.commands.run``) run as **different users**. On ``/tmp``
|
||||
(which has the sticky bit set) this means ``sandbox.files.write`` can
|
||||
create new files but cannot overwrite files previously created by
|
||||
``sandbox.commands.run`` (or itself), because the sticky bit restricts
|
||||
deletion/rename to the file owner.
|
||||
|
||||
To work around this, writes targeting ``/tmp`` are performed via
|
||||
``tee`` through the command API, which runs as the sandbox ``user``
|
||||
and can therefore always overwrite user-owned files.
|
||||
"""
|
||||
if path == "/tmp" or path.startswith("/tmp/"):
|
||||
import base64 as _b64
|
||||
|
||||
encoded = _b64.b64encode(content.encode()).decode()
|
||||
result = await sandbox.commands.run(
|
||||
f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(path)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=10,
|
||||
)
|
||||
if result.exit_code != 0:
|
||||
raise RuntimeError(
|
||||
f"shell write failed (exit {result.exit_code}): "
|
||||
+ (result.stderr or "").strip()
|
||||
)
|
||||
else:
|
||||
await sandbox.files.write(path, content)
|
||||
|
||||
|
||||
# Tool handlers
|
||||
|
||||
|
||||
@@ -104,9 +171,16 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
try:
|
||||
parent = os.path.dirname(remote)
|
||||
if parent and parent != E2B_WORKDIR:
|
||||
if parent and parent not in E2B_ALLOWED_DIRS:
|
||||
await sandbox.files.make_dir(parent)
|
||||
await sandbox.files.write(remote, content)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(
|
||||
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(parent)}",
|
||||
error=True,
|
||||
)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
await _sandbox_write(sandbox, remote, content)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
@@ -130,6 +204,15 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
parent = os.path.dirname(remote)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(
|
||||
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(parent)}",
|
||||
error=True,
|
||||
)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
@@ -152,7 +235,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
else content.replace(old_string, new_string, 1)
|
||||
)
|
||||
try:
|
||||
await sandbox.files.write(remote, updated)
|
||||
await _sandbox_write(sandbox, remote, updated)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
@@ -245,14 +328,14 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
|
||||
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
(
|
||||
"read_file",
|
||||
"Read a file from the cloud sandbox (/home/user). "
|
||||
"Read a file from the cloud sandbox (/home/user or /tmp). "
|
||||
"Use offset and limit for large files.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
@@ -269,7 +352,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
),
|
||||
(
|
||||
"write_file",
|
||||
"Write or create a file in the cloud sandbox (/home/user). "
|
||||
"Write or create a file in the cloud sandbox (/home/user or /tmp). "
|
||||
"Parent directories are created automatically. "
|
||||
"To copy a workspace file into the sandbox, use "
|
||||
"read_workspace_file with save_to_path instead.",
|
||||
@@ -278,7 +361,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
|
||||
},
|
||||
"content": {"type": "string", "description": "Content to write."},
|
||||
},
|
||||
@@ -295,7 +378,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
|
||||
},
|
||||
"old_string": {"type": "string", "description": "Text to find."},
|
||||
"new_string": {"type": "string", "description": "Replacement text."},
|
||||
|
||||
@@ -4,15 +4,20 @@ Pure unit tests with no external dependencies (no E2B, no sandbox).
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .e2b_file_tools import _read_local, resolve_sandbox_path
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
|
||||
|
||||
from .e2b_file_tools import (
|
||||
_check_sandbox_symlink_escape,
|
||||
_read_local,
|
||||
_sandbox_write,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||
@@ -21,46 +26,66 @@ _SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
class TestResolveSandboxPath:
|
||||
def test_relative_path_resolved(self):
|
||||
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
|
||||
assert resolve_sandbox_path("src/main.py") == f"{E2B_WORKDIR}/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
|
||||
assert (
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/file.txt") == f"{E2B_WORKDIR}/file.txt"
|
||||
)
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
assert resolve_sandbox_path(E2B_WORKDIR) == E2B_WORKDIR
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
|
||||
assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == f"{E2B_WORKDIR}/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert resolve_sandbox_path("src/") == "/home/user/src"
|
||||
assert resolve_sandbox_path("src/") == f"{E2B_WORKDIR}/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
"""Path that resolves back within E2B_WORKDIR is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"
|
||||
|
||||
def test_tmp_absolute_allowed(self):
|
||||
assert resolve_sandbox_path("/tmp/data.txt") == "/tmp/data.txt"
|
||||
|
||||
def test_tmp_nested_allowed(self):
|
||||
assert resolve_sandbox_path("/tmp/a/b/c.txt") == "/tmp/a/b/c.txt"
|
||||
|
||||
def test_tmp_itself_allowed(self):
|
||||
assert resolve_sandbox_path("/tmp") == "/tmp"
|
||||
|
||||
def test_tmp_escape_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/tmp/../etc/passwd")
|
||||
|
||||
def test_tmp_prefix_collision_blocked(self):
|
||||
"""A path like /tmp_evil should be blocked (not a prefix match)."""
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/tmp_evil/malicious.txt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -73,9 +98,13 @@ class TestResolveSandboxPath:
|
||||
|
||||
|
||||
class TestReadLocal:
|
||||
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
|
||||
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
"""Create a tool-results file under <encoded>/<uuid>/tool-results/."""
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
@@ -107,7 +136,9 @@ class TestReadLocal:
|
||||
def test_read_nonexistent_tool_results(self):
|
||||
"""A tool-results path that doesn't exist returns FileNotFoundError."""
|
||||
encoded = "-tmp-copilot-e2b-test-nofile"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
@@ -117,7 +148,7 @@ class TestReadLocal:
|
||||
assert "not found" in result["content"][0]["text"].lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.rmdir(tool_results_dir)
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
|
||||
def test_read_traversal_path_blocked(self):
|
||||
"""A traversal attempt that escapes allowed directories is blocked."""
|
||||
@@ -152,3 +183,155 @@ class TestReadLocal:
|
||||
"""Without _current_project_dir set, all paths are blocked."""
|
||||
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_sandbox_symlink_escape — symlink escape detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_sandbox(stdout: str, exit_code: int = 0) -> SimpleNamespace:
|
||||
"""Build a minimal sandbox mock whose commands.run returns a fixed result."""
|
||||
run_result = SimpleNamespace(stdout=stdout, exit_code=exit_code)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
return SimpleNamespace(commands=commands)
|
||||
|
||||
|
||||
class TestCheckSandboxSymlinkEscape:
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonical_path_within_workdir_returns_path(self):
|
||||
"""When readlink -f resolves to a path inside E2B_WORKDIR, returns it."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/src\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result == f"{E2B_WORKDIR}/src"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workdir_itself_returns_workdir(self):
|
||||
"""When readlink -f resolves to E2B_WORKDIR exactly, returns E2B_WORKDIR."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, E2B_WORKDIR)
|
||||
assert result == E2B_WORKDIR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_symlink_escape_returns_none(self):
|
||||
"""When readlink -f resolves outside E2B_WORKDIR (symlink escape), returns None."""
|
||||
sandbox = _make_sandbox(stdout="/etc\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonzero_exit_code_returns_none(self):
|
||||
"""A non-zero exit code from readlink -f returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=1)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_stdout_returns_none(self):
|
||||
"""Empty stdout from readlink (e.g. path doesn't exist yet) returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefix_collision_returns_none(self):
|
||||
"""A path prefixed with E2B_WORKDIR but not within it is rejected."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}-evil\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}-evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deeply_nested_path_within_workdir(self):
|
||||
"""Deep nested paths inside E2B_WORKDIR are allowed."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
|
||||
assert result == f"{E2B_WORKDIR}/a/b/c/d"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_path_allowed(self):
|
||||
"""Paths resolving to /tmp are allowed."""
|
||||
sandbox = _make_sandbox(stdout="/tmp/workdir\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, "/tmp/workdir")
|
||||
assert result == "/tmp/workdir"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_itself_allowed(self):
|
||||
"""The /tmp directory itself is allowed."""
|
||||
sandbox = _make_sandbox(stdout="/tmp\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, "/tmp")
|
||||
assert result == "/tmp"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sandbox_write — routing writes through shell for /tmp paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSandboxWrite:
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_path_uses_shell_command(self):
|
||||
"""Writes to /tmp should use commands.run (shell) instead of files.write."""
|
||||
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
files = SimpleNamespace(write=AsyncMock())
|
||||
sandbox = SimpleNamespace(commands=commands, files=files)
|
||||
|
||||
await _sandbox_write(sandbox, "/tmp/test.py", "print('hello')")
|
||||
|
||||
commands.run.assert_called_once()
|
||||
files.write.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_user_path_uses_files_api(self):
|
||||
"""Writes to /home/user should use sandbox.files.write."""
|
||||
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
files = SimpleNamespace(write=AsyncMock())
|
||||
sandbox = SimpleNamespace(commands=commands, files=files)
|
||||
|
||||
await _sandbox_write(sandbox, "/home/user/test.py", "print('hello')")
|
||||
|
||||
files.write.assert_called_once_with("/home/user/test.py", "print('hello')")
|
||||
commands.run.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_nested_path_uses_shell_command(self):
|
||||
"""Writes to nested /tmp paths should use commands.run."""
|
||||
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
files = SimpleNamespace(write=AsyncMock())
|
||||
sandbox = SimpleNamespace(commands=commands, files=files)
|
||||
|
||||
await _sandbox_write(sandbox, "/tmp/subdir/file.txt", "content")
|
||||
|
||||
commands.run.assert_called_once()
|
||||
files.write.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_write_shell_failure_raises(self):
|
||||
"""Shell write failure should raise RuntimeError."""
|
||||
run_result = SimpleNamespace(stdout="", stderr="No space left", exit_code=1)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
sandbox = SimpleNamespace(commands=commands)
|
||||
|
||||
with pytest.raises(RuntimeError, match="shell write failed"):
|
||||
await _sandbox_write(sandbox, "/tmp/test.txt", "content")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_write_preserves_content_with_special_chars(self):
|
||||
"""Content with special shell characters should be preserved via base64."""
|
||||
import base64
|
||||
|
||||
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
sandbox = SimpleNamespace(commands=commands)
|
||||
|
||||
content = "print(\"Hello $USER\")\n# a `backtick` and 'quotes'\n"
|
||||
await _sandbox_write(sandbox, "/tmp/special.py", content)
|
||||
|
||||
# Verify the command contains base64-encoded content
|
||||
call_args = commands.run.call_args[0][0]
|
||||
# Extract the base64 string from the command
|
||||
encoded_in_cmd = call_args.split("echo ")[1].split(" |")[0].strip("'")
|
||||
decoded = base64.b64decode(encoded_in_cmd).decode()
|
||||
assert decoded == content
|
||||
|
||||
@@ -0,0 +1,651 @@
|
||||
"""Tests for retry logic and transcript compaction helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _friendly_error_text, _is_prompt_too_long
|
||||
from .transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_run_compression,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenAssistantContent:
|
||||
def test_text_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": "World"},
|
||||
]
|
||||
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
|
||||
|
||||
def test_tool_use_blocks(self):
|
||||
blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
|
||||
|
||||
def test_mixed_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Let me read that."},
|
||||
{"type": "tool_use", "name": "Read", "input": {"path": "/foo"}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "Let me read that." in result
|
||||
assert "[tool_use: Read]" in result
|
||||
|
||||
def test_raw_strings(self):
|
||||
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "See this image:"},
|
||||
{"type": "image", "source": {"type": "base64", "data": "..."}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "See this image:" in result
|
||||
assert "[__image__]" in result
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_assistant_content([]) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_tool_result_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenToolResultContent:
|
||||
def test_tool_result_with_text(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": [{"type": "text", "text": "file contents here"}],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "file contents here"
|
||||
|
||||
def test_tool_result_with_string_content(self):
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "123", "content": "ok"}]
|
||||
assert _flatten_tool_result_content(blocks) == "ok"
|
||||
|
||||
def test_text_block(self):
|
||||
blocks = [{"type": "text", "text": "plain text"}]
|
||||
assert _flatten_tool_result_content(blocks) == "plain text"
|
||||
|
||||
def test_raw_string(self):
|
||||
assert _flatten_tool_result_content(["raw"]) == "raw"
|
||||
|
||||
def test_tool_result_with_none_content(self):
|
||||
"""tool_result with content=None should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": None}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_tool_result_with_empty_list_content(self):
|
||||
"""tool_result with content=[] should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": []}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_tool_result_content([]) == ""
|
||||
|
||||
def test_nested_dict_without_text(self):
|
||||
"""Dict blocks without text key use json.dumps fallback."""
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "x",
|
||||
"content": [{"type": "image", "source": "data:..."}],
|
||||
}
|
||||
]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "image" in result # json.dumps fallback
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [{"type": "image", "source": {"type": "base64", "data": "..."}}]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "[__image__]" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transcript_to_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_entry(entry_type: str, role: str, content: str | list, **kwargs) -> str:
|
||||
"""Build a JSONL line for testing."""
|
||||
uid = str(uuid4())
|
||||
msg: dict = {"role": role, "content": content}
|
||||
msg.update(kwargs)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": None,
|
||||
"message": msg,
|
||||
}
|
||||
return json.dumps(entry, separators=(",", ":"))
|
||||
|
||||
|
||||
class TestTranscriptToMessages:
|
||||
def test_basic_roundtrip(self):
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == {"role": "user", "content": "Hello"}
|
||||
assert messages[1] == {"role": "assistant", "content": "Hi"}
|
||||
|
||||
def test_skips_strippable_types(self):
|
||||
"""Progress and metadata entries are excluded."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "progress",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"role": "assistant", "content": "..."},
|
||||
}
|
||||
),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _transcript_to_messages("") == []
|
||||
|
||||
def test_tool_result_content(self):
|
||||
"""User entries with tool_result content blocks are flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": "tool output",
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "tool output"
|
||||
|
||||
def test_malformed_json_lines_skipped(self):
|
||||
"""Malformed JSON lines in transcript are silently skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"this is not valid json",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_lines_skipped(self):
|
||||
"""Empty lines and whitespace-only lines are skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"",
|
||||
" ",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_unicode_content_preserved(self):
|
||||
"""Unicode characters survive transcript roundtrip."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello 你好 🌍"),
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[{"type": "text", "text": "Bonjour 日本語 émojis 🎉"}],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert messages[0]["content"] == "Hello 你好 🌍"
|
||||
assert messages[1]["content"] == "Bonjour 日本語 émojis 🎉"
|
||||
|
||||
def test_entry_without_role_skipped(self):
|
||||
"""Entries with missing role in message are skipped."""
|
||||
entry_no_role = json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"content": "no role here"},
|
||||
}
|
||||
)
|
||||
lines = [
|
||||
entry_no_role,
|
||||
_make_entry("user", "user", "Hello"),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "Hello"
|
||||
|
||||
def test_tool_use_and_result_pairs(self):
|
||||
"""Tool use + tool result pairs are properly flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[
|
||||
{"type": "text", "text": "Let me check."},
|
||||
{"type": "tool_use", "name": "read_file", "input": {"path": "/x"}},
|
||||
],
|
||||
),
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "abc",
|
||||
"content": [{"type": "text", "text": "file contents"}],
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert "Let me check." in messages[0]["content"]
|
||||
assert "[tool_use: read_file]" in messages[0]["content"]
|
||||
assert messages[1]["content"] == "file contents"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _messages_to_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessagesToTranscript:
|
||||
def test_produces_valid_jsonl(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
assert len(lines) == 2
|
||||
for line in lines:
|
||||
parsed = json.loads(line)
|
||||
assert "type" in parsed
|
||||
assert "uuid" in parsed
|
||||
assert "message" in parsed
|
||||
|
||||
def test_assistant_has_proper_structure(self):
|
||||
messages = [{"role": "assistant", "content": "Hello"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "assistant"
|
||||
msg = entry["message"]
|
||||
assert msg["role"] == "assistant"
|
||||
assert msg["type"] == "message"
|
||||
assert msg["stop_reason"] == "end_turn"
|
||||
assert isinstance(msg["content"], list)
|
||||
assert msg["content"][0]["type"] == "text"
|
||||
|
||||
def test_user_has_plain_content(self):
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "user"
|
||||
assert entry["message"]["content"] == "Hi"
|
||||
|
||||
def test_parent_uuid_chain(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "A"},
|
||||
{"role": "assistant", "content": "B"},
|
||||
{"role": "user", "content": "C"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
entries = [json.loads(line) for line in lines]
|
||||
assert entries[0]["parentUuid"] == ""
|
||||
assert entries[1]["parentUuid"] == entries[0]["uuid"]
|
||||
assert entries[2]["parentUuid"] == entries[1]["uuid"]
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _messages_to_transcript([]) == ""
|
||||
|
||||
def test_output_is_valid_transcript(self):
|
||||
"""Output should pass validate_transcript if it has assistant entries."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result)
|
||||
|
||||
def test_roundtrip_to_messages(self):
|
||||
"""Messages → transcript → messages preserves structure."""
|
||||
original = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
transcript = _messages_to_transcript(original)
|
||||
restored = _transcript_to_messages(transcript)
|
||||
assert len(restored) == len(original)
|
||||
for orig, rest in zip(original, restored):
|
||||
assert orig["role"] == rest["role"]
|
||||
assert orig["content"] == rest["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compact_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_few_messages_returns_none(self, mock_chat_config):
|
||||
"""compact_transcript returns None when transcript has < 2 messages."""
|
||||
transcript = _build_transcript([("user", "Hello")])
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_compacted(self, mock_chat_config):
|
||||
"""When compress_context says no compaction needed, returns None.
|
||||
The compressor couldn't reduce it, so retrying with the same
|
||||
content would fail identically."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi there"),
|
||||
]
|
||||
)
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": False,
|
||||
"messages": [],
|
||||
"original_token_count": 100,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_compacted_transcript(self, mock_chat_config):
|
||||
"""When compaction succeeds, returns a valid compacted transcript."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
("user", "More"),
|
||||
("assistant", "Details"),
|
||||
]
|
||||
)
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 500,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[1]["content"] == "Summarized response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_compression_failure(self, mock_chat_config):
|
||||
"""When _run_compression raises, returns None."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
]
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LLM unavailable"),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
"""Unit tests for _is_prompt_too_long pattern matching."""
|
||||
|
||||
def test_prompt_is_too_long(self):
|
||||
err = RuntimeError("prompt is too long for model context")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_request_too_large(self):
|
||||
err = Exception("request too large: 250000 tokens")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_maximum_context_length(self):
|
||||
err = ValueError("maximum context length exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_context_length_exceeded(self):
|
||||
err = Exception("context_length_exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_tokens_exceed(self):
|
||||
err = Exception("input tokens exceed the max_tokens limit")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_is_too_long(self):
|
||||
err = Exception("input is too long for the model")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_content_length_exceeds(self):
|
||||
err = Exception("content length exceeds maximum")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_unrelated_error_returns_false(self):
|
||||
err = RuntimeError("network timeout")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_auth_error_returns_false(self):
|
||||
err = Exception("authentication failed: invalid API key")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_chained_exception_detected(self):
|
||||
"""Prompt-too-long error wrapped in another exception is detected."""
|
||||
inner = RuntimeError("prompt is too long")
|
||||
outer = Exception("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
err = Exception("PROMPT IS TOO LONG")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_old_max_tokens_exceeded_not_matched(self):
|
||||
"""The old broad 'max_tokens_exceeded' pattern was removed.
|
||||
Only 'input tokens exceed' should match now."""
|
||||
err = Exception("max_tokens_exceeded")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_compression timeout fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunCompressionTimeout:
|
||||
"""Verify _run_compression falls back to truncation when LLM times out."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_falls_back_to_truncation(self):
|
||||
"""When compress_context with LLM client times out,
|
||||
_run_compression falls back to truncation (client=None)."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
truncation_result = CompressResult(
|
||||
messages=messages,
|
||||
was_compacted=False,
|
||||
original_token_count=50,
|
||||
token_count=50,
|
||||
messages_summarized=0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
call_args: list[dict] = []
|
||||
|
||||
async def _mock_compress(**kwargs):
|
||||
call_args.append(kwargs)
|
||||
if kwargs.get("client") is not None:
|
||||
# Simulate timeout by raising asyncio.TimeoutError
|
||||
raise asyncio.TimeoutError("LLM compaction timed out")
|
||||
return truncation_result
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value="fake-client",
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_mock_compress,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(messages, "test-model", "[test]")
|
||||
|
||||
assert result == truncation_result
|
||||
# Should have been called twice: once with client, once without
|
||||
assert len(call_args) == 2
|
||||
assert call_args[0]["client"] is not None # LLM attempt
|
||||
assert call_args[1]["client"] is None # truncation fallback
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation_directly(self):
|
||||
"""When no OpenAI client is configured, goes straight to truncation."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
truncation_result = CompressResult(
|
||||
messages=messages,
|
||||
was_compacted=False,
|
||||
original_token_count=50,
|
||||
token_count=50,
|
||||
messages_summarized=0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(messages, "test-model", "[test]")
|
||||
|
||||
assert result == truncation_result
|
||||
mock_compress.assert_called_once()
|
||||
# When no client, compress_context is called with client=None
|
||||
assert mock_compress.call_args.kwargs.get("client") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _friendly_error_text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFriendlyErrorText:
|
||||
"""Verify user-friendly error message mapping."""
|
||||
|
||||
def test_authentication_error(self):
|
||||
result = _friendly_error_text("authentication failed: invalid API key")
|
||||
assert "Authentication" in result
|
||||
assert "API key" in result
|
||||
|
||||
def test_rate_limit_error(self):
|
||||
result = _friendly_error_text("rate limit exceeded")
|
||||
assert "Rate limit" in result
|
||||
|
||||
def test_overloaded_error(self):
|
||||
result = _friendly_error_text("API is overloaded")
|
||||
assert "overloaded" in result
|
||||
|
||||
def test_timeout_error(self):
|
||||
result = _friendly_error_text("Request timeout after 30s")
|
||||
assert "timed out" in result
|
||||
|
||||
def test_connection_error(self):
|
||||
result = _friendly_error_text("Connection refused")
|
||||
assert "Connection" in result or "connection" in result
|
||||
|
||||
def test_unknown_error_passthrough(self):
|
||||
result = _friendly_error_text("some unknown error XYZ")
|
||||
assert "SDK stream error:" in result
|
||||
assert "XYZ" in result
|
||||
|
||||
def test_unauthorized_error(self):
|
||||
result = _friendly_error_text("401 Unauthorized")
|
||||
assert "Authentication" in result
|
||||
@@ -20,6 +20,7 @@ from claude_agent_sdk import (
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from backend.copilot.constants import FRIENDLY_TRANSIENT_MSG, is_transient_api_error
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -214,10 +215,12 @@ class SDKResponseAdapter:
|
||||
if sdk_message.subtype == "success":
|
||||
responses.append(StreamFinish())
|
||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||
error_msg = sdk_message.result or "Unknown error"
|
||||
responses.append(
|
||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||
)
|
||||
raw_error = str(sdk_message.result or "Unknown error")
|
||||
if is_transient_api_error(raw_error):
|
||||
error_text, code = FRIENDLY_TRANSIENT_MSG, "transient_api_error"
|
||||
else:
|
||||
error_text, code = raw_error, "sdk_error"
|
||||
responses.append(StreamError(errorText=error_text, code=code))
|
||||
responses.append(StreamFinish())
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
1410
autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py
Normal file
1410
autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -42,7 +42,7 @@ def _validate_workspace_path(
|
||||
Delegates to :func:`is_allowed_local_path` which permits:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The current session's tool-results directory
|
||||
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
|
||||
(``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
@@ -302,7 +302,11 @@ def create_security_hooks(
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
# Sanitize untrusted input before logging to prevent log injection
|
||||
# Sanitize untrusted input: strip control chars for logging AND
|
||||
# for the value passed downstream. read_compacted_entries()
|
||||
# validates against _projects_base() as defence-in-depth, but
|
||||
# sanitizing here prevents log injection and rejects obviously
|
||||
# malformed paths early.
|
||||
transcript_path = (
|
||||
str(input_data.get("transcript_path", ""))
|
||||
.replace("\n", "")
|
||||
|
||||
@@ -122,7 +122,7 @@ def test_read_no_cwd_denies_absolute():
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-results/12345.txt"
|
||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,337 @@
|
||||
"""Unit tests for extracted service helpers.
|
||||
|
||||
Covers ``_is_prompt_too_long``, ``_reduce_context``, ``_iter_sdk_messages``,
|
||||
``ReducedContext``, and the ``is_parallel_continuation`` logic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from claude_agent_sdk import AssistantMessage, TextBlock, ToolUseBlock
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import (
|
||||
ReducedContext,
|
||||
_is_prompt_too_long,
|
||||
_is_tool_only_message,
|
||||
_iter_sdk_messages,
|
||||
_reduce_context,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
def test_direct_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("prompt is too long")) is True
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("PROMPT IS TOO LONG")) is True
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("network timeout")) is False
|
||||
|
||||
def test_request_too_large(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("request too large for model")) is True
|
||||
|
||||
def test_context_length_exceeded(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("context_length_exceeded")) is True
|
||||
|
||||
def test_max_tokens_exceeded_not_matched(self) -> None:
|
||||
"""'max_tokens_exceeded' is intentionally excluded (too broad)."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens_exceeded")) is False
|
||||
|
||||
def test_max_tokens_config_error_no_match(self) -> None:
|
||||
"""'max_tokens must be at least 1' should NOT match."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens must be at least 1")) is False
|
||||
|
||||
def test_chained_cause(self) -> None:
|
||||
inner = Exception("prompt is too long")
|
||||
outer = RuntimeError("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_chained_context(self) -> None:
|
||||
inner = Exception("request too large")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__context__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_deep_chain(self) -> None:
|
||||
bottom = Exception("maximum context length")
|
||||
middle = RuntimeError("middle")
|
||||
middle.__cause__ = bottom
|
||||
top = ValueError("top")
|
||||
top.__cause__ = middle
|
||||
assert _is_prompt_too_long(top) is True
|
||||
|
||||
def test_chain_no_match(self) -> None:
|
||||
inner = Exception("rate limit exceeded")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is False
|
||||
|
||||
def test_cycle_detection(self) -> None:
|
||||
"""Exception chain with a cycle should not infinite-loop."""
|
||||
a = Exception("error a")
|
||||
b = Exception("error b")
|
||||
a.__cause__ = b
|
||||
b.__cause__ = a # cycle
|
||||
assert _is_prompt_too_long(a) is False
|
||||
|
||||
def test_all_patterns(self) -> None:
|
||||
patterns = [
|
||||
"prompt is too long",
|
||||
"request too large",
|
||||
"maximum context length",
|
||||
"context_length_exceeded",
|
||||
"input tokens exceed",
|
||||
"input is too long",
|
||||
"content length exceeds",
|
||||
]
|
||||
for pattern in patterns:
|
||||
assert _is_prompt_too_long(Exception(pattern)) is True, pattern
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reduce_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReduceContext:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_retry_compaction_success(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value="/tmp/resume.jsonl",
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert isinstance(ctx, ReducedContext)
|
||||
assert ctx.use_resume is True
|
||||
assert ctx.resume_file == "/tmp/resume.jsonl"
|
||||
assert ctx.transcript_lost is False
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_fails_drops_transcript(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.resume_file is None
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_tried_compaction_skips(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
ctx = await _reduce_context(transcript, True, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_drops(self) -> None:
|
||||
ctx = await _reduce_context("", False, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_returns_same_content_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript, # same content
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_tempfile_fails_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _iter_sdk_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIterSdkMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_yields_messages(self) -> None:
|
||||
messages = ["msg1", "msg2", "msg3"]
|
||||
client = AsyncMock()
|
||||
|
||||
async def _fake_receive() -> AsyncGenerator[str]:
|
||||
for m in messages:
|
||||
yield m
|
||||
|
||||
client.receive_response = _fake_receive
|
||||
result = [msg async for msg in _iter_sdk_messages(client)]
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartbeat_on_timeout(self) -> None:
|
||||
"""Yields None when asyncio.wait times out."""
|
||||
client = AsyncMock()
|
||||
received: list = []
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
await asyncio.sleep(100) # never completes
|
||||
yield "never" # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
with patch("backend.copilot.sdk.service._HEARTBEAT_INTERVAL", 0.01):
|
||||
count = 0
|
||||
async for msg in _iter_sdk_messages(client):
|
||||
received.append(msg)
|
||||
count += 1
|
||||
if count >= 3:
|
||||
break
|
||||
|
||||
assert all(m is None for m in received)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_propagates(self) -> None:
|
||||
client = AsyncMock()
|
||||
|
||||
async def _error_receive() -> AsyncGenerator[str]:
|
||||
raise RuntimeError("SDK crash")
|
||||
yield # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _error_receive
|
||||
|
||||
with pytest.raises(RuntimeError, match="SDK crash"):
|
||||
async for _ in _iter_sdk_messages(client):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cleanup_on_break(self) -> None:
|
||||
"""Pending task is cancelled when generator is closed."""
|
||||
client = AsyncMock()
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
yield "first"
|
||||
await asyncio.sleep(100)
|
||||
yield "second"
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
gen = _iter_sdk_messages(client)
|
||||
first = await gen.__anext__()
|
||||
assert first == "first"
|
||||
await gen.aclose() # should cancel pending task cleanly
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_parallel_continuation logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsParallelContinuation:
|
||||
"""Unit tests for the is_parallel_continuation expression in the streaming loop.
|
||||
|
||||
Verifies the vacuous-truth guard (empty content must return False) and the
|
||||
boundary cases for mixed TextBlock+ToolUseBlock messages.
|
||||
"""
|
||||
|
||||
def _make_tool_block(self) -> MagicMock:
|
||||
block = MagicMock(spec=ToolUseBlock)
|
||||
return block
|
||||
|
||||
def test_all_tool_use_blocks_is_parallel(self):
|
||||
"""AssistantMessage with only ToolUseBlocks is a parallel continuation."""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
msg.content = [self._make_tool_block(), self._make_tool_block()]
|
||||
assert _is_tool_only_message(msg) is True
|
||||
|
||||
def test_empty_content_is_not_parallel(self):
|
||||
"""AssistantMessage with empty content must NOT be treated as parallel.
|
||||
|
||||
Without the bool(sdk_msg.content) guard, all() on an empty iterable
|
||||
returns True via vacuous truth — this test ensures the guard is present.
|
||||
"""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
msg.content = []
|
||||
assert _is_tool_only_message(msg) is False
|
||||
|
||||
def test_mixed_text_and_tool_blocks_not_parallel(self):
|
||||
"""AssistantMessage with text + tool blocks is NOT a parallel continuation."""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
text_block = MagicMock(spec=TextBlock)
|
||||
msg.content = [text_block, self._make_tool_block()]
|
||||
assert _is_tool_only_message(msg) is False
|
||||
|
||||
def test_non_assistant_message_not_parallel(self):
|
||||
"""Non-AssistantMessage types are never parallel continuations."""
|
||||
assert _is_tool_only_message("not a message") is False
|
||||
assert _is_tool_only_message(None) is False
|
||||
assert _is_tool_only_message(42) is False
|
||||
|
||||
def test_single_tool_block_is_parallel(self):
|
||||
"""Single ToolUseBlock AssistantMessage is a parallel continuation."""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
msg.content = [self._make_tool_block()]
|
||||
assert _is_tool_only_message(msg) is True
|
||||
@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .service import _prepare_file_attachments
|
||||
from .service import _prepare_file_attachments, _resolve_sdk_model
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -288,3 +288,214 @@ class TestPromptSupplement:
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cleanup_sdk_tool_results — orchestration + rate-limiting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupSdkToolResults:
|
||||
"""Tests for _cleanup_sdk_tool_results orchestration and sweep rate-limiting."""
|
||||
|
||||
# All valid cwds must start with /tmp/copilot- (the _SDK_CWD_PREFIX).
|
||||
_CWD_PREFIX = "/tmp/copilot-"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removes_cwd_directory(self):
|
||||
"""Cleanup removes the session working directory."""
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-cleanup-remove"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch("backend.copilot.sdk.service.cleanup_stale_project_dirs"):
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
svc_mod._last_sweep_time = 0.0
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
assert not os.path.exists(cwd)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sweep_runs_when_interval_elapsed(self):
|
||||
"""cleanup_stale_project_dirs is called when 5-minute interval has elapsed."""
|
||||
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-sweep-elapsed"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
# Set last sweep to a time far in the past
|
||||
svc_mod._last_sweep_time = 0.0
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
mock_sweep.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sweep_skipped_within_interval(self):
|
||||
"""cleanup_stale_project_dirs is NOT called when within 5-minute interval."""
|
||||
import time
|
||||
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-sweep-ratelimit"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
# Set last sweep to now — interval not elapsed
|
||||
svc_mod._last_sweep_time = time.time()
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
mock_sweep.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_path_outside_prefix(self, tmp_path):
|
||||
"""Cleanup rejects a cwd that does not start with the expected prefix."""
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
evil_cwd = str(tmp_path / "evil-path")
|
||||
os.makedirs(evil_cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
await _cleanup_sdk_tool_results(evil_cwd)
|
||||
|
||||
# Directory should NOT have been removed (rejected early)
|
||||
assert os.path.exists(evil_cwd)
|
||||
mock_sweep.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Env vars that ChatConfig validators read — must be cleared so explicit
|
||||
# constructor values are used.
|
||||
# ---------------------------------------------------------------------------
|
||||
_CONFIG_ENV_VARS = (
|
||||
"CHAT_USE_OPENROUTER",
|
||||
"CHAT_API_KEY",
|
||||
"OPEN_ROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"CHAT_BASE_URL",
|
||||
"OPENROUTER_BASE_URL",
|
||||
"OPENAI_BASE_URL",
|
||||
"CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
|
||||
"CHAT_USE_CLAUDE_AGENT_SDK",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _clean_config_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _CONFIG_ENV_VARS:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestResolveSdkModel:
|
||||
"""Tests for _resolve_sdk_model — model ID resolution for the SDK CLI."""
|
||||
|
||||
def test_openrouter_active_keeps_dots(self, monkeypatch, _clean_config_env):
|
||||
"""When OpenRouter is fully active, model keeps dot-separated version."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4.6"
|
||||
|
||||
def test_openrouter_disabled_normalizes_to_hyphens(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""When OpenRouter is disabled, dots are replaced with hyphens."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
def test_openrouter_enabled_but_missing_key_normalizes(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""When OpenRouter is enabled but api_key is missing, falls back to
|
||||
direct Anthropic and normalizes dots to hyphens."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key=None,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
def test_explicit_claude_agent_model_takes_precedence(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""When claude_agent_model is explicitly set, it is returned as-is."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model="claude-sonnet-4-5-20250514",
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-sonnet-4-5-20250514"
|
||||
|
||||
def test_subscription_mode_returns_none(self, monkeypatch, _clean_config_env):
|
||||
"""When using Claude Code subscription, returns None (CLI picks model)."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=True,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() is None
|
||||
|
||||
def test_model_without_provider_prefix(self, monkeypatch, _clean_config_env):
|
||||
"""When model has no provider prefix, it still normalizes correctly."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
144
autogpt_platform/backend/backend/copilot/sdk/subscription.py
Normal file
144
autogpt_platform/backend/backend/copilot/sdk/subscription.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Claude Code subscription auth helpers.
|
||||
|
||||
Handles locating the SDK-bundled CLI binary, provisioning credentials from
|
||||
environment variables, and validating that subscription auth is functional.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_bundled_cli() -> str:
|
||||
"""Locate the Claude CLI binary bundled inside ``claude_agent_sdk``.
|
||||
|
||||
Falls back to ``shutil.which("claude")`` if the SDK bundle is absent.
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk._internal.transport.subprocess_cli import (
|
||||
SubprocessCLITransport,
|
||||
)
|
||||
|
||||
path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
|
||||
if path:
|
||||
return str(path)
|
||||
except Exception:
|
||||
pass
|
||||
system_path = shutil.which("claude")
|
||||
if system_path:
|
||||
return system_path
|
||||
raise RuntimeError(
|
||||
"Claude CLI not found — neither the SDK-bundled binary nor a "
|
||||
"system-installed `claude` could be located."
|
||||
)
|
||||
|
||||
|
||||
def provision_credentials_file() -> None:
|
||||
"""Write ``~/.claude/.credentials.json`` from env when running headless.
|
||||
|
||||
If ``CLAUDE_CODE_OAUTH_TOKEN`` is set (an OAuth *access* token obtained
|
||||
from ``claude auth status`` or extracted from the macOS keychain), this
|
||||
helper writes a minimal credentials file so the bundled CLI can
|
||||
authenticate without an interactive ``claude login``.
|
||||
|
||||
A ``CLAUDE_CODE_REFRESH_TOKEN`` env var is optional but recommended —
|
||||
it lets the CLI silently refresh an expired access token.
|
||||
"""
|
||||
access_token = os.environ.get("CLAUDE_CODE_OAUTH_TOKEN", "").strip()
|
||||
if not access_token:
|
||||
return
|
||||
|
||||
creds_dir = os.path.expanduser("~/.claude")
|
||||
creds_path = os.path.join(creds_dir, ".credentials.json")
|
||||
|
||||
# Don't overwrite an existing credentials file (e.g. from a volume mount).
|
||||
if os.path.exists(creds_path):
|
||||
logger.debug("Credentials file already exists at %s — skipping", creds_path)
|
||||
return
|
||||
|
||||
os.makedirs(creds_dir, exist_ok=True)
|
||||
|
||||
creds = {
|
||||
"claudeAiOauth": {
|
||||
"accessToken": access_token,
|
||||
"refreshToken": os.environ.get("CLAUDE_CODE_REFRESH_TOKEN", "").strip(),
|
||||
"expiresAt": 0,
|
||||
"scopes": [
|
||||
"user:inference",
|
||||
"user:profile",
|
||||
"user:sessions:claude_code",
|
||||
],
|
||||
}
|
||||
}
|
||||
with open(creds_path, "w") as f:
|
||||
json.dump(creds, f)
|
||||
logger.info("Provisioned Claude credentials file at %s", creds_path)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def validate_subscription() -> None:
|
||||
"""Validate the bundled Claude CLI is reachable and authenticated.
|
||||
|
||||
Cached so the blocking subprocess check runs at most once per process
|
||||
lifetime. On first call, also provisions ``~/.claude/.credentials.json``
|
||||
from the ``CLAUDE_CODE_OAUTH_TOKEN`` env var when available.
|
||||
"""
|
||||
provision_credentials_file()
|
||||
|
||||
cli = find_bundled_cli()
|
||||
result = subprocess.run(
|
||||
[cli, "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Claude CLI check failed (exit {result.returncode}): "
|
||||
f"{result.stderr.strip()}"
|
||||
)
|
||||
logger.info(
|
||||
"Claude Code subscription mode: CLI version %s",
|
||||
result.stdout.strip(),
|
||||
)
|
||||
|
||||
# Verify the CLI is actually authenticated.
|
||||
auth_result = subprocess.run(
|
||||
[cli, "auth", "status"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
env={
|
||||
**os.environ,
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
},
|
||||
)
|
||||
if auth_result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
"Claude CLI is not authenticated. Either:\n"
|
||||
" • Set CLAUDE_CODE_OAUTH_TOKEN env var (from `claude auth status` "
|
||||
"or macOS keychain), or\n"
|
||||
" • Mount ~/.claude/.credentials.json into the container, or\n"
|
||||
" • Run `claude login` inside the container."
|
||||
)
|
||||
try:
|
||||
status = json.loads(auth_result.stdout)
|
||||
if not status.get("loggedIn"):
|
||||
raise RuntimeError(
|
||||
"Claude CLI reports loggedIn=false. Set CLAUDE_CODE_OAUTH_TOKEN "
|
||||
"or run `claude login`."
|
||||
)
|
||||
logger.info(
|
||||
"Claude subscription auth: method=%s, email=%s",
|
||||
status.get("authMethod"),
|
||||
status.get("email"),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Could not parse `claude auth status` output")
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Tests for the tool call circuit breaker in tool_adapter.py."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.tool_adapter import (
|
||||
_MAX_CONSECUTIVE_TOOL_FAILURES,
|
||||
_check_circuit_breaker,
|
||||
_clear_tool_failures,
|
||||
_consecutive_tool_failures,
|
||||
_record_tool_failure,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_tracker():
|
||||
"""Reset the circuit breaker tracker for each test."""
|
||||
token = _consecutive_tool_failures.set({})
|
||||
yield
|
||||
_consecutive_tool_failures.reset(token)
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
def test_no_trip_below_threshold(self):
|
||||
"""Circuit breaker should not trip before reaching the limit."""
|
||||
args = {"file_path": "/tmp/test.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
_record_tool_failure("write_file", args)
|
||||
# Still under the limit
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
|
||||
def test_trips_at_threshold(self):
|
||||
"""Circuit breaker should trip after reaching the failure limit."""
|
||||
args = {"file_path": "/tmp/test.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
_record_tool_failure("write_file", args)
|
||||
# Now it should trip
|
||||
result = _check_circuit_breaker("write_file", args)
|
||||
assert result is not None
|
||||
assert "STOP" in result
|
||||
assert "write_file" in result
|
||||
|
||||
def test_different_args_tracked_separately(self):
|
||||
"""Different args should have separate failure counters."""
|
||||
args_a = {"file_path": "/tmp/a.txt"}
|
||||
args_b = {"file_path": "/tmp/b.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("write_file", args_a)
|
||||
# args_a should trip
|
||||
assert _check_circuit_breaker("write_file", args_a) is not None
|
||||
# args_b should NOT trip
|
||||
assert _check_circuit_breaker("write_file", args_b) is None
|
||||
|
||||
def test_different_tools_tracked_separately(self):
|
||||
"""Different tools should have separate failure counters."""
|
||||
args = {"file_path": "/tmp/test.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("tool_a", args)
|
||||
# tool_a should trip
|
||||
assert _check_circuit_breaker("tool_a", args) is not None
|
||||
# tool_b with same args should NOT trip
|
||||
assert _check_circuit_breaker("tool_b", args) is None
|
||||
|
||||
def test_empty_args_tracked(self):
|
||||
"""Empty args ({}) — the exact failure pattern from the bug — should be tracked."""
|
||||
args = {}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("write_file", args)
|
||||
assert _check_circuit_breaker("write_file", args) is not None
|
||||
|
||||
def test_clear_resets_counter(self):
|
||||
"""Clearing failures should reset the counter."""
|
||||
args = {}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("write_file", args)
|
||||
_clear_tool_failures("write_file")
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
|
||||
def test_success_clears_failures(self):
|
||||
"""A successful call should reset the failure counter."""
|
||||
args = {}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
|
||||
_record_tool_failure("write_file", args)
|
||||
# Success clears failures
|
||||
_clear_tool_failures("write_file")
|
||||
# Should be able to fail again without tripping
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
|
||||
_record_tool_failure("write_file", args)
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
|
||||
def test_no_tracker_returns_none(self):
|
||||
"""If tracker is not initialized, circuit breaker should not trip."""
|
||||
_consecutive_tool_failures.set(None) # type: ignore[arg-type]
|
||||
_record_tool_failure("write_file", {}) # should not raise
|
||||
assert _check_circuit_breaker("write_file", {}) is None
|
||||
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
from backend.copilot.context import (
|
||||
_current_permissions,
|
||||
_current_project_dir,
|
||||
_current_sandbox,
|
||||
_current_sdk_cwd,
|
||||
@@ -41,6 +42,8 @@ from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
|
||||
@@ -50,6 +53,14 @@ _MCP_MAX_CHARS = 500_000
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
||||
|
||||
# Map from tool_name -> Queue of pre-launched (task, args) pairs.
|
||||
# Initialised per-session in set_execution_context() so concurrent sessions
|
||||
# never share the same dict.
|
||||
_TaskQueueItem = tuple[asyncio.Task[dict[str, Any]], dict[str, Any]]
|
||||
_tool_task_queues: ContextVar[dict[str, asyncio.Queue[_TaskQueueItem]] | None] = (
|
||||
ContextVar("_tool_task_queues", default=None)
|
||||
)
|
||||
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
@@ -66,12 +77,23 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
|
||||
"_stash_event", default=None
|
||||
)
|
||||
|
||||
# Circuit breaker: tracks consecutive tool failures to detect infinite retry loops.
|
||||
# When a tool is called repeatedly with empty/identical args and keeps failing,
|
||||
# this counter is incremented. After _MAX_CONSECUTIVE_TOOL_FAILURES identical
|
||||
# failures the tool handler returns a hard-stop message instead of the raw error.
|
||||
_MAX_CONSECUTIVE_TOOL_FAILURES = 3
|
||||
_consecutive_tool_failures: ContextVar[dict[str, int]] = ContextVar(
|
||||
"_consecutive_tool_failures",
|
||||
default=None, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
@@ -83,14 +105,83 @@ def set_execution_context(
|
||||
session: Current chat session.
|
||||
sandbox: Optional E2B sandbox; when set, bash_exec routes commands there.
|
||||
sdk_cwd: SDK working directory; used to scope tool-results reads.
|
||||
permissions: Optional capability filter restricting tools/blocks.
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
_current_permissions.set(permissions)
|
||||
_pending_tool_outputs.set({})
|
||||
_stash_event.set(asyncio.Event())
|
||||
_tool_task_queues.set({})
|
||||
_consecutive_tool_failures.set({})
|
||||
|
||||
|
||||
def reset_stash_event() -> None:
|
||||
"""Clear any stale stash signal left over from a previous stream attempt.
|
||||
|
||||
``_stash_event`` is set once per session in ``set_execution_context`` and
|
||||
reused across retry attempts. A PostToolUse hook from a failed attempt may
|
||||
leave the event set; calling this at the start of each retry prevents
|
||||
``wait_for_stash`` from returning prematurely on a stale signal.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is not None:
|
||||
event.clear()
|
||||
|
||||
|
||||
async def cancel_pending_tool_tasks() -> None:
|
||||
"""Cancel all queued pre-launched tasks for the current execution context.
|
||||
|
||||
Call this when a stream attempt aborts (error, cancellation) to prevent
|
||||
pre-launched tasks from continuing to execute against a rolled-back session.
|
||||
Tasks that are already done are skipped; in-flight tasks are cancelled and
|
||||
awaited so that any cleanup (``finally`` blocks, DB rollbacks) completes
|
||||
before the next retry starts.
|
||||
"""
|
||||
queues = _tool_task_queues.get()
|
||||
if not queues:
|
||||
return
|
||||
cancelled_tasks: list[asyncio.Task] = []
|
||||
for tool_name, queue in list(queues.items()):
|
||||
cancelled = 0
|
||||
while not queue.empty():
|
||||
task, _args = queue.get_nowait()
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
cancelled_tasks.append(task)
|
||||
cancelled += 1
|
||||
if cancelled:
|
||||
logger.debug(
|
||||
"Cancelled %d pre-launched task(s) for tool '%s'", cancelled, tool_name
|
||||
)
|
||||
queues.clear()
|
||||
# Await all cancelled tasks so their cleanup (finally blocks, DB rollbacks)
|
||||
# completes before the next retry attempt starts new pre-launches.
|
||||
# Use a timeout to prevent hanging indefinitely if a task's cleanup is stuck.
|
||||
if cancelled_tasks:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*cancelled_tasks, return_exceptions=True),
|
||||
timeout=5.0,
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Timed out waiting for %d cancelled task(s) to clean up",
|
||||
len(cancelled_tasks),
|
||||
)
|
||||
|
||||
|
||||
def reset_tool_failure_counters() -> None:
|
||||
"""Reset all tool-level circuit breaker counters.
|
||||
|
||||
Called at the start of each SDK retry attempt so that failure counts
|
||||
from a previous (rolled-back) attempt do not carry over and prematurely
|
||||
trip the breaker on a fresh attempt with different context.
|
||||
"""
|
||||
_consecutive_tool_failures.set({})
|
||||
|
||||
|
||||
def pop_pending_tool_output(tool_name: str) -> str | None:
|
||||
@@ -146,7 +237,7 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
|
||||
event.set()
|
||||
|
||||
|
||||
async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
async def wait_for_stash(timeout: float = 2.0) -> bool:
|
||||
"""Wait for a PostToolUse hook to stash tool output.
|
||||
|
||||
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
|
||||
@@ -155,12 +246,13 @@ async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
by waiting on the ``_stash_event``, which is signaled by
|
||||
:func:`stash_pending_tool_output`.
|
||||
|
||||
After the event fires, callers should ``await asyncio.sleep(0)`` to
|
||||
give any remaining concurrent hooks a chance to complete.
|
||||
Uses ``asyncio.Event.wait()`` so it returns the instant the hook signals —
|
||||
the timeout is purely a safety net for the case where the hook never fires.
|
||||
Returns ``True`` if the stash signal was received, ``False`` on timeout.
|
||||
|
||||
Returns ``True`` if a stash signal was received, ``False`` on timeout.
|
||||
The timeout is a safety net — normally the stash happens within
|
||||
microseconds of yielding to the event loop.
|
||||
The 2.0 s default was chosen to accommodate slower tool startup in cloud
|
||||
sandboxes while still failing fast when the hook genuinely will not fire.
|
||||
With the parallel pre-launch path, hooks typically fire well under 1 ms.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is None:
|
||||
@@ -169,7 +261,7 @@ async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
if event.is_set():
|
||||
event.clear()
|
||||
return True
|
||||
# Slow path: wait for the hook to signal.
|
||||
# Slow path: block until the hook signals or the safety timeout expires.
|
||||
try:
|
||||
async with asyncio.timeout(timeout):
|
||||
await event.wait()
|
||||
@@ -179,6 +271,82 @@ async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def pre_launch_tool_call(tool_name: str, args: dict[str, Any]) -> None:
|
||||
"""Pre-launch a tool as a background task so parallel calls run concurrently.
|
||||
|
||||
Called when an AssistantMessage with ToolUseBlocks is received, before the
|
||||
SDK dispatches the MCP tool/call requests. The tool_handler will await the
|
||||
pre-launched task instead of executing fresh.
|
||||
|
||||
The tool_name may include an MCP prefix (e.g. ``mcp__copilot__run_block``);
|
||||
the prefix is stripped automatically before looking up the tool.
|
||||
|
||||
Ordering guarantee: the Claude Agent SDK dispatches MCP ``tools/call`` requests
|
||||
in the same order as the ToolUseBlocks appear in the AssistantMessage.
|
||||
Pre-launched tasks are queued FIFO per tool name, so the N-th handler for a
|
||||
given tool name dequeues the N-th pre-launched task — result and args always
|
||||
correspond when the SDK preserves order (which it does in the current SDK).
|
||||
"""
|
||||
queues = _tool_task_queues.get()
|
||||
if queues is None:
|
||||
return
|
||||
|
||||
# Strip the MCP server prefix (e.g. "mcp__copilot__") to get the bare tool name.
|
||||
# Use removeprefix so tool names that themselves contain "__" are handled correctly.
|
||||
bare_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
|
||||
|
||||
base_tool = TOOL_REGISTRY.get(bare_name)
|
||||
if base_tool is None:
|
||||
return
|
||||
|
||||
user_id, session = get_execution_context()
|
||||
if session is None:
|
||||
return
|
||||
|
||||
# Expand @@agptfile: references before launching the task.
|
||||
# The _truncating wrapper (which normally handles expansion) runs AFTER
|
||||
# pre_launch_tool_call — the pre-launched task would otherwise receive raw
|
||||
# @@agptfile: tokens and fail to resolve them inside _execute_tool_sync.
|
||||
# Use _build_input_schema (same path as _truncating) for schema-aware expansion.
|
||||
input_schema: dict[str, Any] | None
|
||||
try:
|
||||
input_schema = _build_input_schema(base_tool)
|
||||
except Exception:
|
||||
input_schema = None # schema unavailable — skip schema-aware expansion
|
||||
try:
|
||||
args = await expand_file_refs_in_args(
|
||||
args, user_id, session, input_schema=input_schema
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
logger.warning(
|
||||
"pre_launch_tool_call: @@agptfile expansion failed for %s: %s — skipping pre-launch",
|
||||
bare_name,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
task = asyncio.create_task(_execute_tool_sync(base_tool, user_id, session, args))
|
||||
# Log unhandled exceptions so "Task exception was never retrieved" warnings
|
||||
# do not pollute stderr when a task is pre-launched but never dequeued.
|
||||
task.add_done_callback(
|
||||
lambda t, name=bare_name: (
|
||||
logger.warning(
|
||||
"Pre-launched task for %s raised unhandled: %s",
|
||||
name,
|
||||
t.exception(),
|
||||
)
|
||||
if not t.cancelled() and t.exception()
|
||||
else None
|
||||
)
|
||||
)
|
||||
|
||||
if bare_name not in queues:
|
||||
queues[bare_name] = asyncio.Queue[_TaskQueueItem]()
|
||||
# Store (task, args) so the handler can log a warning if the SDK dispatches
|
||||
# calls in a different order than the ToolUseBlocks appeared in the message.
|
||||
queues[bare_name].put_nowait((task, args))
|
||||
|
||||
|
||||
async def _execute_tool_sync(
|
||||
base_tool: BaseTool,
|
||||
user_id: str | None,
|
||||
@@ -187,8 +355,10 @@ async def _execute_tool_sync(
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a tool synchronously and return MCP-formatted response.
|
||||
|
||||
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
|
||||
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
|
||||
Note: ``@@agptfile:`` expansion should be performed by the caller before
|
||||
invoking this function. For the normal (non-parallel) path it is handled
|
||||
by the ``_truncating`` wrapper; for the pre-launched parallel path it is
|
||||
handled in :func:`pre_launch_tool_call` before the task is created.
|
||||
"""
|
||||
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
result = await base_tool.execute(
|
||||
@@ -217,6 +387,66 @@ def _mcp_error(message: str) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _failure_key(tool_name: str, args: dict[str, Any]) -> str:
|
||||
"""Compute a stable fingerprint for (tool_name, args) used by the circuit breaker."""
|
||||
args_key = json.dumps(args, sort_keys=True, default=str)
|
||||
return f"{tool_name}:{args_key}"
|
||||
|
||||
|
||||
def _check_circuit_breaker(tool_name: str, args: dict[str, Any]) -> str | None:
|
||||
"""Check if a tool has hit the consecutive failure limit.
|
||||
|
||||
Tracks failures keyed by (tool_name, args_fingerprint). Returns an error
|
||||
message if the circuit breaker has tripped, or None if the call should proceed.
|
||||
"""
|
||||
tracker = _consecutive_tool_failures.get(None)
|
||||
if tracker is None:
|
||||
return None
|
||||
|
||||
key = _failure_key(tool_name, args)
|
||||
count = tracker.get(key, 0)
|
||||
if count >= _MAX_CONSECUTIVE_TOOL_FAILURES:
|
||||
logger.warning(
|
||||
"Circuit breaker tripped for tool %s after %d consecutive "
|
||||
"identical failures (args=%s)",
|
||||
tool_name,
|
||||
count,
|
||||
key[len(tool_name) + 1 :][:200],
|
||||
)
|
||||
return (
|
||||
f"STOP: Tool '{tool_name}' has failed {count} consecutive times with "
|
||||
f"the same arguments. Do NOT retry this tool call. "
|
||||
f"If you were trying to write content to a file, instead respond with "
|
||||
f"the content directly as a text message to the user."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _record_tool_failure(tool_name: str, args: dict[str, Any]) -> None:
|
||||
"""Record a tool failure for circuit breaker tracking."""
|
||||
tracker = _consecutive_tool_failures.get(None)
|
||||
if tracker is None:
|
||||
return
|
||||
key = _failure_key(tool_name, args)
|
||||
tracker[key] = tracker.get(key, 0) + 1
|
||||
|
||||
|
||||
def _clear_tool_failures(tool_name: str) -> None:
|
||||
"""Clear failure tracking for a tool on success.
|
||||
|
||||
Clears ALL args variants for the tool, not just the successful call's args.
|
||||
This gives the tool a "fresh start" on any success, which is appropriate for
|
||||
the primary use case (detecting infinite loops with identical failing args).
|
||||
"""
|
||||
tracker = _consecutive_tool_failures.get(None)
|
||||
if tracker is None:
|
||||
return
|
||||
# Clear all entries for this tool name
|
||||
keys_to_remove = [k for k in tracker if k.startswith(f"{tool_name}:")]
|
||||
for k in keys_to_remove:
|
||||
del tracker[k]
|
||||
|
||||
|
||||
def create_tool_handler(base_tool: BaseTool):
|
||||
"""Create an async handler function for a BaseTool.
|
||||
|
||||
@@ -225,7 +455,83 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
"""
|
||||
|
||||
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the wrapped tool and return MCP-formatted response."""
|
||||
"""Execute the wrapped tool and return MCP-formatted response.
|
||||
|
||||
If a pre-launched task exists (from parallel tool pre-launch in the
|
||||
message loop), await it instead of executing fresh.
|
||||
"""
|
||||
queues = _tool_task_queues.get()
|
||||
if queues and base_tool.name in queues:
|
||||
queue = queues[base_tool.name]
|
||||
if not queue.empty():
|
||||
task, launch_args = queue.get_nowait()
|
||||
# Sanity-check: warn if the args don't match — this can happen
|
||||
# if the SDK dispatches tool calls in a different order than the
|
||||
# ToolUseBlocks appeared in the AssistantMessage (unlikely but
|
||||
# could occur in future SDK versions or with SDK bugs).
|
||||
# We compare full values (not just keys) so that two run_block
|
||||
# calls with different block_id values are caught even though
|
||||
# both have the same key set.
|
||||
if launch_args != args:
|
||||
logger.warning(
|
||||
"Pre-launched task for %s: arg mismatch "
|
||||
"(launch_keys=%s, call_keys=%s) — cancelling "
|
||||
"pre-launched task and falling back to direct execution",
|
||||
base_tool.name,
|
||||
(
|
||||
sorted(launch_args.keys())
|
||||
if isinstance(launch_args, dict)
|
||||
else type(launch_args).__name__
|
||||
),
|
||||
(
|
||||
sorted(args.keys())
|
||||
if isinstance(args, dict)
|
||||
else type(args).__name__
|
||||
),
|
||||
)
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
# Await cancellation to prevent duplicate concurrent
|
||||
# execution for blocks with side effects.
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
# Fall through to the direct-execution path below.
|
||||
else:
|
||||
# Args match — await the pre-launched task.
|
||||
try:
|
||||
result = await task
|
||||
except asyncio.CancelledError:
|
||||
# Re-raise: CancelledError may be propagating from the
|
||||
# outer streaming loop being cancelled — swallowing it
|
||||
# would mask the cancellation and prevent proper cleanup.
|
||||
logger.warning(
|
||||
"Pre-launched tool %s was cancelled — re-raising",
|
||||
base_tool.name,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Pre-launched tool %s failed: %s",
|
||||
base_tool.name,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return _mcp_error(
|
||||
f"Failed to execute {base_tool.name}. "
|
||||
"Check server logs for details."
|
||||
)
|
||||
|
||||
# Pre-truncate the result so the _truncating wrapper (which
|
||||
# wraps this handler) receives an already-within-budget
|
||||
# value. _truncating handles stashing — we must NOT stash
|
||||
# here or the output will be appended twice to the FIFO
|
||||
# queue and pop_pending_tool_output would return a duplicate
|
||||
# entry on the second call for the same tool.
|
||||
return truncate(result, _MCP_MAX_CHARS)
|
||||
|
||||
# No pre-launched task — execute directly (fallback for non-parallel calls).
|
||||
user_id, session = get_execution_context()
|
||||
|
||||
if session is None:
|
||||
@@ -234,8 +540,12 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||
logger.error(
|
||||
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
|
||||
)
|
||||
return _mcp_error(
|
||||
f"Failed to execute {base_tool.name}. Check server logs for details."
|
||||
)
|
||||
|
||||
return tool_handler
|
||||
|
||||
@@ -285,7 +595,7 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
resolved = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(resolved) as f:
|
||||
with open(resolved, encoding="utf-8", errors="replace") as f:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
@@ -358,6 +668,15 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
Applied once to every registered tool."""
|
||||
|
||||
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
|
||||
# Circuit breaker: stop infinite retry loops with identical args.
|
||||
# Use the original (pre-expansion) args for fingerprinting so
|
||||
# check and record always use the same key — @@agptfile:
|
||||
# expansion mutates args, which would cause a key mismatch.
|
||||
original_args = args
|
||||
stop_msg = _check_circuit_breaker(tool_name, original_args)
|
||||
if stop_msg:
|
||||
return _mcp_error(stop_msg)
|
||||
|
||||
user_id, session = get_execution_context()
|
||||
if session is not None:
|
||||
try:
|
||||
@@ -365,6 +684,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
args, user_id, session, input_schema=input_schema
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
_record_tool_failure(tool_name, original_args)
|
||||
return _mcp_error(
|
||||
f"@@agptfile: reference could not be resolved: {exc}. "
|
||||
"Ensure the file exists before referencing it. "
|
||||
@@ -374,6 +694,12 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
result = await fn(args)
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
|
||||
# Track consecutive failures for circuit breaker
|
||||
if truncated.get("isError"):
|
||||
_record_tool_failure(tool_name, original_args)
|
||||
else:
|
||||
_clear_tool_failures(tool_name)
|
||||
|
||||
# Stash the text so the response adapter can forward our
|
||||
# middle-out truncated version to the frontend instead of the
|
||||
# SDK's head-truncated version (for outputs >~100 KB the SDK
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
"""Tests for tool_adapter helpers: truncation, stash, context vars."""
|
||||
"""Tests for tool_adapter helpers: truncation, stash, context vars, parallel pre-launch."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import get_sdk_cwd
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .tool_adapter import (
|
||||
_MCP_MAX_CHARS,
|
||||
_text_from_mcp_result,
|
||||
cancel_pending_tool_tasks,
|
||||
create_tool_handler,
|
||||
pop_pending_tool_output,
|
||||
pre_launch_tool_call,
|
||||
reset_stash_event,
|
||||
set_execution_context,
|
||||
stash_pending_tool_output,
|
||||
wait_for_stash,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -120,6 +130,69 @@ class TestToolOutputStash:
|
||||
assert pop_pending_tool_output("a") == "alpha"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset_stash_event / wait_for_stash
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetStashEvent:
|
||||
"""Tests for reset_stash_event — the stale-signal fix for retry attempts."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_context(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_clears_stale_signal(self):
|
||||
"""After reset, wait_for_stash does NOT return immediately (blocks until timeout)."""
|
||||
# Simulate a stale signal left by a failed attempt's PostToolUse hook.
|
||||
stash_pending_tool_output("some_tool", "stale output")
|
||||
# The stash_pending_tool_output call sets the event.
|
||||
# Now reset it — simulating start of a new retry attempt.
|
||||
reset_stash_event()
|
||||
# wait_for_stash should block and time out since the event was cleared.
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is False, (
|
||||
"wait_for_stash should have timed out after reset_stash_event, "
|
||||
"but it returned True — stale signal was not cleared"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_returns_true_when_signaled_after_reset(self):
|
||||
"""After reset, a new stash signal is correctly detected."""
|
||||
reset_stash_event()
|
||||
|
||||
async def _signal_after_delay():
|
||||
await asyncio.sleep(0.01)
|
||||
stash_pending_tool_output("tool", "fresh output")
|
||||
|
||||
asyncio.create_task(_signal_after_delay())
|
||||
result = await wait_for_stash(timeout=1.0)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_scenario_stale_event_does_not_fire_prematurely(self):
|
||||
"""Simulates: attempt 1 leaves event set → reset → attempt 2 waits correctly."""
|
||||
# Attempt 1: hook fires and sets the event
|
||||
stash_pending_tool_output("t", "attempt-1-output")
|
||||
# Pop it so the stash is empty (simulating normal consumption)
|
||||
pop_pending_tool_output("t")
|
||||
|
||||
# Between attempts: reset (as service.py does before each retry)
|
||||
reset_stash_event()
|
||||
|
||||
# Attempt 2: wait_for_stash should NOT return True immediately
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is False, (
|
||||
"Stale event from attempt 1 caused wait_for_stash to return "
|
||||
"prematurely in attempt 2"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncating wrapper (integration via create_copilot_mcp_server)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -168,3 +241,534 @@ class TestTruncationAndStashIntegration:
|
||||
text = _text_from_mcp_result(truncated)
|
||||
assert len(text) < len(big_text)
|
||||
assert len(str(truncated)) <= _MCP_MAX_CHARS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parallel pre-launch infrastructure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_tool(name: str, output: str = "result") -> MagicMock:
|
||||
"""Return a BaseTool mock that returns a successful StreamToolOutputAvailable."""
|
||||
tool = MagicMock()
|
||||
tool.name = name
|
||||
tool.parameters = {"properties": {}, "required": []}
|
||||
tool.execute = AsyncMock(
|
||||
return_value=StreamToolOutputAvailable(
|
||||
toolCallId="test-id",
|
||||
output=output,
|
||||
toolName=name,
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
def _make_mock_session() -> MagicMock:
|
||||
"""Return a minimal ChatSession mock."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
def _init_ctx(session=None):
|
||||
set_execution_context(
|
||||
user_id="user-1",
|
||||
session=session, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
)
|
||||
|
||||
|
||||
class TestPreLaunchToolCall:
|
||||
"""Tests for pre_launch_tool_call and the queue-based parallel dispatch."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_is_silently_ignored(self):
|
||||
"""pre_launch_tool_call does nothing for tools not in TOOL_REGISTRY."""
|
||||
# Should not raise even if the tool name is completely unknown
|
||||
await pre_launch_tool_call("nonexistent_tool", {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_prefix_stripped_before_registry_lookup(self):
|
||||
"""mcp__copilot__run_block is looked up as 'run_block'."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("mcp__copilot__run_block", {"block_id": "b1"})
|
||||
|
||||
# The task was enqueued — mock_tool.execute should be called once
|
||||
# (may not complete immediately but should start)
|
||||
await asyncio.sleep(0) # yield to event loop
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_tool_name_without_prefix(self):
|
||||
"""Tool names without __ separator are looked up as-is."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
|
||||
await asyncio.sleep(0)
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_enqueued_fifo_for_same_tool(self):
|
||||
"""Two pre-launched calls for the same tool name are enqueued FIFO."""
|
||||
results = []
|
||||
|
||||
async def slow_execute(*args, **kwargs):
|
||||
results.append(len(results))
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="id",
|
||||
output=str(len(results) - 1),
|
||||
toolName="t",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("t")
|
||||
mock_tool.execute = AsyncMock(side_effect=slow_execute)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"t": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("t", {"n": 1})
|
||||
await pre_launch_tool_call("t", {"n": 2})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_ref_expansion_failure_skips_pre_launch(self):
|
||||
"""When @@agptfile: expansion fails, pre_launch_tool_call skips the task.
|
||||
|
||||
The handler should then fall back to direct execution (which will also
|
||||
fail with a proper MCP error via _truncating's own expansion).
|
||||
"""
|
||||
mock_tool = _make_mock_tool("run_block", output="should-not-execute")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.tool_adapter.expand_file_refs_in_args",
|
||||
AsyncMock(side_effect=FileRefExpansionError("@@agptfile:missing.txt")),
|
||||
),
|
||||
):
|
||||
# Should not raise — expansion failure is handled gracefully
|
||||
await pre_launch_tool_call("run_block", {"text": "@@agptfile:missing.txt"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# No task was pre-launched — execute was not called
|
||||
mock_tool.execute.assert_not_awaited()
|
||||
|
||||
|
||||
class TestCreateToolHandlerParallel:
|
||||
"""Tests for create_tool_handler using pre-launched tasks."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_uses_prelaunched_task(self):
|
||||
"""Handler pops and awaits the pre-launched task rather than re-executing."""
|
||||
mock_tool = _make_mock_tool("run_block", output="pre-launched result")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0) # let task start
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
text = result["content"][0]["text"]
|
||||
assert "pre-launched result" in text
|
||||
# Should only have been called once (the pre-launched task), not twice
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_does_not_double_stash_for_prelaunched_task(self):
|
||||
"""Pre-launched task result must NOT be stashed by tool_handler directly.
|
||||
|
||||
The _truncating wrapper wraps tool_handler and handles stashing after
|
||||
tool_handler returns. If tool_handler also stashed, the output would be
|
||||
appended twice to the FIFO queue and pop_pending_tool_output would return
|
||||
a duplicate on the second call.
|
||||
|
||||
This test calls tool_handler directly (without _truncating) and asserts
|
||||
that nothing was stashed — confirming stashing is deferred to _truncating.
|
||||
"""
|
||||
mock_tool = _make_mock_tool("run_block", output="stash-me")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
assert "stash-me" in result["content"][0]["text"]
|
||||
# tool_handler must NOT stash — _truncating (which wraps handler) does it.
|
||||
# Calling pop here (without going through _truncating) should return None.
|
||||
not_stashed = pop_pending_tool_output("run_block")
|
||||
assert not_stashed is None, (
|
||||
"tool_handler must not stash directly — _truncating handles stashing "
|
||||
"to prevent double-stash in the FIFO queue"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_falls_back_when_queue_empty(self):
|
||||
"""When no pre-launched task exists, handler executes directly."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct result")
|
||||
|
||||
# Don't call pre_launch_tool_call — queue is empty
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
text = result["content"][0]["text"]
|
||||
assert "direct result" in text
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_cancelled_error_propagates(self):
|
||||
"""CancelledError from a pre-launched task is re-raised to preserve cancellation semantics."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await handler({"block_id": "b1"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_exception_returns_mcp_error(self):
|
||||
"""Exception from a pre-launched task is caught and returned as MCP error."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is True
|
||||
assert "Failed to execute run_block" in result["content"][0]["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_same_tool_calls_dispatched_in_order(self):
|
||||
"""Two pre-launched tasks for the same tool are consumed in FIFO order."""
|
||||
call_order = []
|
||||
|
||||
async def execute_with_tag(*args, **kwargs):
|
||||
tag = kwargs.get("block_id", "?")
|
||||
call_order.append(tag)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="id", output=f"out-{tag}", toolName="run_block", success=True
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=execute_with_tag)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "first"})
|
||||
await pre_launch_tool_call("run_block", {"block_id": "second"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
r1 = await handler({"block_id": "first"})
|
||||
r2 = await handler({"block_id": "second"})
|
||||
|
||||
assert "out-first" in r1["content"][0]["text"]
|
||||
assert "out-second" in r2["content"][0]["text"]
|
||||
assert call_order == [
|
||||
"first",
|
||||
"second",
|
||||
], f"Expected FIFO dispatch order but got {call_order}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arg_mismatch_falls_back_to_direct_execution(self):
|
||||
"""When pre-launched args differ from SDK args, handler cancels pre-launched
|
||||
task and falls back to direct execution with the correct args."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct-result")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
# Pre-launch with args {"block_id": "wrong"}
|
||||
await pre_launch_tool_call("run_block", {"block_id": "wrong"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# SDK dispatches with different args
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "correct"})
|
||||
|
||||
assert result["isError"] is False
|
||||
# The tool was called twice: once by pre-launch (wrong args), once by
|
||||
# direct fallback (correct args). The result should come from the
|
||||
# direct execution path.
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_session_falls_back_gracefully(self):
|
||||
"""When session is None and no pre-launched task, handler returns MCP error."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
# session=None means get_execution_context returns (user_id, None)
|
||||
set_execution_context(user_id="u", session=None, sandbox=None) # type: ignore[arg-type]
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is True
|
||||
assert "session" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cancel_pending_tool_tasks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCancelPendingToolTasks:
|
||||
"""Tests for cancel_pending_tool_tasks — the stream-abort cleanup helper."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancels_queued_tasks(self):
|
||||
"""Queued tasks are cancelled and the queue is cleared."""
|
||||
ran = False
|
||||
|
||||
async def never_run(*_args, **_kwargs):
|
||||
nonlocal ran
|
||||
await asyncio.sleep(10) # long enough to still be pending
|
||||
ran = True
|
||||
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=never_run)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0) # let task start
|
||||
await cancel_pending_tool_tasks()
|
||||
await asyncio.sleep(0) # let cancellation propagate
|
||||
|
||||
assert not ran, "Task should have been cancelled before completing"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noop_when_no_tasks_queued(self):
|
||||
"""cancel_pending_tool_tasks does not raise when queues are empty."""
|
||||
await cancel_pending_tool_tasks() # should not raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_does_not_find_cancelled_task(self):
|
||||
"""After cancel, tool_handler falls back to direct execution."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct-fallback")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
await cancel_pending_tool_tasks()
|
||||
|
||||
# Queue is now empty — handler should execute directly
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
assert "direct-fallback" in result["content"][0]["text"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concurrent / parallel pre-launch scenarios
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAllParallelToolsPrelaunchedIndependently:
|
||||
"""Simulate SDK sending N separate AssistantMessages for the same tool concurrently."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_parallel_tools_prelaunched_independently(self):
|
||||
"""5 pre-launches for the same tool all enqueue independently and run concurrently.
|
||||
|
||||
Each task sleeps for PER_TASK_S seconds. If they ran sequentially the total
|
||||
wall time would be ~5*PER_TASK_S. Running concurrently it should finish in
|
||||
roughly PER_TASK_S (plus scheduling overhead).
|
||||
"""
|
||||
PER_TASK_S = 0.05
|
||||
N = 5
|
||||
started: list[int] = []
|
||||
finished: list[int] = []
|
||||
|
||||
async def slow_execute(*args, **kwargs):
|
||||
idx = len(started)
|
||||
started.append(idx)
|
||||
await asyncio.sleep(PER_TASK_S)
|
||||
finished.append(idx)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=f"id-{idx}",
|
||||
output=f"result-{idx}",
|
||||
toolName="bash_exec",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("bash_exec")
|
||||
mock_tool.execute = AsyncMock(side_effect=slow_execute)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"bash_exec": mock_tool},
|
||||
):
|
||||
for i in range(N):
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": f"echo {i}"})
|
||||
|
||||
# Measure only the concurrent execution window, not pre-launch overhead.
|
||||
# Starting the timer here avoids false failures on slow CI runners where
|
||||
# the pre_launch_tool_call setup takes longer than the concurrent sleep.
|
||||
t0 = asyncio.get_running_loop().time()
|
||||
await asyncio.sleep(PER_TASK_S * 2)
|
||||
elapsed = asyncio.get_running_loop().time() - t0
|
||||
|
||||
assert mock_tool.execute.await_count == N
|
||||
assert len(finished) == N
|
||||
# Wall time of the sleep window should be well under N * PER_TASK_S
|
||||
# (sequential would be ~0.25s; concurrent finishes in ~PER_TASK_S = 0.05s)
|
||||
assert elapsed < N * PER_TASK_S, (
|
||||
f"Expected concurrent execution (<{N * PER_TASK_S:.2f}s) "
|
||||
f"but sleep window took {elapsed:.2f}s"
|
||||
)
|
||||
|
||||
|
||||
class TestHandlerReturnsResultFromCorrectPrelaunchedTask:
|
||||
"""Pop pre-launched tasks in order and verify each returns its own result."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_returns_result_from_correct_prelaunched_task(self):
|
||||
"""Two pre-launches for the same tool: first handler gets first result, second gets second."""
|
||||
|
||||
async def execute_with_cmd(*args, **kwargs):
|
||||
cmd = kwargs.get("cmd", "?")
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="id",
|
||||
output=f"output-for-{cmd}",
|
||||
toolName="bash_exec",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("bash_exec")
|
||||
mock_tool.execute = AsyncMock(side_effect=execute_with_cmd)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"bash_exec": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": "alpha"})
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": "beta"})
|
||||
await asyncio.sleep(0) # let both tasks start
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
r1 = await handler({"cmd": "alpha"})
|
||||
r2 = await handler({"cmd": "beta"})
|
||||
|
||||
text1 = r1["content"][0]["text"]
|
||||
text2 = r2["content"][0]["text"]
|
||||
assert "output-for-alpha" in text1, f"Expected alpha result, got: {text1}"
|
||||
assert "output-for-beta" in text2, f"Expected beta result, got: {text2}"
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
|
||||
class TestFiveConcurrentPrelaunchAllComplete:
|
||||
"""Pre-launch 5 tasks; consume all 5 via handlers; assert all succeed."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_five_concurrent_prelaunch_all_complete(self):
|
||||
"""All 5 pre-launched tasks complete and return successful results."""
|
||||
N = 5
|
||||
call_count = 0
|
||||
|
||||
async def counting_execute(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
n = call_count
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=f"id-{n}",
|
||||
output=f"done-{n}",
|
||||
toolName="bash_exec",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("bash_exec")
|
||||
mock_tool.execute = AsyncMock(side_effect=counting_execute)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"bash_exec": mock_tool},
|
||||
):
|
||||
for i in range(N):
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": f"task-{i}"})
|
||||
|
||||
await asyncio.sleep(0) # let all tasks start
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
results = []
|
||||
for i in range(N):
|
||||
results.append(await handler({"cmd": f"task-{i}"}))
|
||||
|
||||
assert (
|
||||
mock_tool.execute.await_count == N
|
||||
), f"Expected {N} execute calls, got {mock_tool.execute.await_count}"
|
||||
for i, result in enumerate(results):
|
||||
assert result["isError"] is False, f"Result {i} should not be an error"
|
||||
text = result["content"][0]["text"]
|
||||
assert "done-" in text, f"Result {i} missing expected output: {text}"
|
||||
|
||||
@@ -10,6 +10,9 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -17,8 +20,12 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -99,7 +106,14 @@ def strip_progress_entries(content: str) -> str:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
while parent in stripped_uuids:
|
||||
# seen_parents is local per-entry (not shared across iterations) so
|
||||
# it can only detect cycles within a single ancestry walk, not across
|
||||
# entries. This is intentional: each entry's parent chain is
|
||||
# independent, and reusing a global set would incorrectly short-circuit
|
||||
# valid re-use of the same UUID as a parent in different subtrees.
|
||||
seen_parents: set[str] = set()
|
||||
while parent in stripped_uuids and parent not in seen_parents:
|
||||
seen_parents.add(parent)
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
@@ -151,44 +165,110 @@ def _projects_base() -> str:
|
||||
return os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
|
||||
|
||||
def _cli_project_dir(sdk_cwd: str) -> str | None:
|
||||
"""Return the CLI's project directory for a given working directory.
|
||||
_STALE_PROJECT_DIR_SECONDS = 12 * 3600 # 12 hours — matches max session lifetime
|
||||
_MAX_PROJECT_DIRS_TO_SWEEP = 50 # limit per sweep to avoid long pauses
|
||||
|
||||
Returns ``None`` if the path would escape the projects base.
|
||||
|
||||
def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
"""Remove CLI project directories older than ``_STALE_PROJECT_DIR_SECONDS``.
|
||||
|
||||
Each CoPilot SDK turn creates a unique ``~/.claude/projects/<encoded-cwd>/``
|
||||
directory. These are intentionally kept across turns so the model can read
|
||||
tool-result files via ``--resume``. However, after a session ends they
|
||||
become stale. This function sweeps old ones to prevent unbounded disk
|
||||
growth.
|
||||
|
||||
When *encoded_cwd* is provided the sweep is scoped to that single
|
||||
directory, making the operation safe in multi-tenant environments where
|
||||
multiple copilot sessions share the same host. Without it the function
|
||||
falls back to sweeping all directories matching the copilot naming pattern
|
||||
(``-tmp-copilot-``), which is only safe for single-tenant deployments.
|
||||
|
||||
Returns the number of directories removed.
|
||||
"""
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
projects_base = _projects_base()
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
if not os.path.isdir(projects_base):
|
||||
return 0
|
||||
|
||||
if not project_dir.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
"[Transcript] Project dir escaped projects base: %s", project_dir
|
||||
)
|
||||
return None
|
||||
return project_dir
|
||||
now = time.time()
|
||||
removed = 0
|
||||
|
||||
|
||||
def _safe_glob_jsonl(project_dir: str) -> list[Path]:
|
||||
"""Glob ``*.jsonl`` files, filtering out symlinks that escape the directory."""
|
||||
try:
|
||||
resolved_base = Path(project_dir).resolve()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to resolve project dir: %s", e)
|
||||
return []
|
||||
|
||||
result: list[Path] = []
|
||||
for candidate in Path(project_dir).glob("*.jsonl"):
|
||||
try:
|
||||
resolved = candidate.resolve()
|
||||
if resolved.is_relative_to(resolved_base):
|
||||
result.append(resolved)
|
||||
except (OSError, RuntimeError) as e:
|
||||
logger.debug(
|
||||
"[Transcript] Skipping invalid CLI session candidate %s: %s",
|
||||
candidate,
|
||||
e,
|
||||
# Scoped mode: only clean up the one directory for the current session.
|
||||
if encoded_cwd:
|
||||
target = Path(projects_base) / encoded_cwd
|
||||
if not target.is_dir():
|
||||
return 0
|
||||
# Guard: only sweep copilot-generated dirs.
|
||||
if "-tmp-copilot-" not in target.name:
|
||||
logger.warning(
|
||||
"[Transcript] Refusing to sweep non-copilot dir: %s", target.name
|
||||
)
|
||||
return result
|
||||
return 0
|
||||
try:
|
||||
# st_mtime is used as a proxy for session activity. Claude CLI writes
|
||||
# its JSONL transcript into this directory during each turn, so mtime
|
||||
# advances on every turn. A directory whose mtime is older than
|
||||
# _STALE_PROJECT_DIR_SECONDS has not had an active turn in that window
|
||||
# and is safe to remove (the session cannot --resume after cleanup).
|
||||
age = now - target.stat().st_mtime
|
||||
except OSError:
|
||||
return 0
|
||||
if age < _STALE_PROJECT_DIR_SECONDS:
|
||||
return 0
|
||||
try:
|
||||
shutil.rmtree(target, ignore_errors=True)
|
||||
removed = 1
|
||||
except OSError:
|
||||
pass
|
||||
if removed:
|
||||
logger.info(
|
||||
"[Transcript] Swept stale CLI project dir %s (age %ds > %ds)",
|
||||
target.name,
|
||||
int(age),
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
)
|
||||
return removed
|
||||
|
||||
# Unscoped fallback: sweep all copilot dirs across the projects base.
|
||||
# Only safe for single-tenant deployments; callers should prefer the
|
||||
# scoped variant by passing encoded_cwd.
|
||||
try:
|
||||
entries = Path(projects_base).iterdir()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to list projects dir: %s", e)
|
||||
return 0
|
||||
|
||||
for entry in entries:
|
||||
if removed >= _MAX_PROJECT_DIRS_TO_SWEEP:
|
||||
break
|
||||
# Only sweep copilot-generated dirs (pattern: -tmp-copilot- or
|
||||
# -private-tmp-copilot-).
|
||||
if "-tmp-copilot-" not in entry.name:
|
||||
continue
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
try:
|
||||
# See the scoped-mode comment above: st_mtime advances on every turn,
|
||||
# so a stale mtime reliably indicates an inactive session.
|
||||
age = now - entry.stat().st_mtime
|
||||
except OSError:
|
||||
continue
|
||||
if age < _STALE_PROJECT_DIR_SECONDS:
|
||||
continue
|
||||
|
||||
try:
|
||||
shutil.rmtree(entry, ignore_errors=True)
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if removed:
|
||||
logger.info(
|
||||
"[Transcript] Swept %d stale CLI project dirs (older than %ds)",
|
||||
removed,
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
)
|
||||
return removed
|
||||
|
||||
|
||||
def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
@@ -255,63 +335,6 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
return entries
|
||||
|
||||
|
||||
def read_cli_session_file(sdk_cwd: str) -> str | None:
|
||||
"""Read the CLI's own session file, which reflects any compaction.
|
||||
|
||||
The CLI writes its session transcript to
|
||||
``~/.claude/projects/<encoded_cwd>/<session_id>.jsonl``.
|
||||
Since each SDK turn uses a unique ``sdk_cwd``, there should be
|
||||
exactly one ``.jsonl`` file in that directory.
|
||||
|
||||
Returns the file content, or ``None`` if not found.
|
||||
"""
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir or not os.path.isdir(project_dir):
|
||||
return None
|
||||
|
||||
jsonl_files = _safe_glob_jsonl(project_dir)
|
||||
if not jsonl_files:
|
||||
logger.debug("[Transcript] No CLI session file found in %s", project_dir)
|
||||
return None
|
||||
|
||||
# Pick the most recently modified file (should be only one per turn).
|
||||
try:
|
||||
session_file = max(jsonl_files, key=lambda p: p.stat().st_mtime)
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to inspect CLI session files: %s", e)
|
||||
return None
|
||||
|
||||
try:
|
||||
content = session_file.read_text()
|
||||
logger.info(
|
||||
"[Transcript] Read CLI session file: %s (%d bytes)",
|
||||
session_file,
|
||||
len(content),
|
||||
)
|
||||
return content
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
|
||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
||||
safe to remove entirely after the transcript has been uploaded.
|
||||
"""
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir:
|
||||
return
|
||||
|
||||
if os.path.isdir(project_dir):
|
||||
shutil.rmtree(project_dir, ignore_errors=True)
|
||||
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
|
||||
else:
|
||||
logger.debug("[Transcript] Project dir not found: %s", project_dir)
|
||||
|
||||
|
||||
def write_transcript_to_tempfile(
|
||||
transcript_content: str,
|
||||
session_id: str,
|
||||
@@ -327,7 +350,7 @@ def write_transcript_to_tempfile(
|
||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||
real_cwd = os.path.realpath(cwd)
|
||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -337,17 +360,17 @@ def write_transcript_to_tempfile(
|
||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||
)
|
||||
if not jsonl_path.startswith(real_cwd):
|
||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
|
||||
return None
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||
logger.warning("[Transcript] Failed to write resume file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
@@ -408,8 +431,6 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
|
||||
|
||||
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
wid, fid, fname = parts
|
||||
if isinstance(backend, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||
@@ -448,17 +469,15 @@ async def upload_transcript(
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Strip metadata entries (progress, file-history-snapshot, etc.)
|
||||
# Note: SDK-built transcripts shouldn't have these, but strip for safety
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types: list[str] = []
|
||||
for line in stripped.strip().split("\n"):
|
||||
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
|
||||
entry_types.append(entry.get("type", "?"))
|
||||
entry_types = [
|
||||
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
|
||||
for line in stripped.strip().split("\n")
|
||||
]
|
||||
logger.warning(
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
@@ -494,11 +513,14 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -512,8 +534,6 @@ async def download_transcript(
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -521,10 +541,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -536,10 +556,14 @@ async def download_transcript(
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except (FileNotFoundError, Exception):
|
||||
except FileNotFoundError:
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
except Exception as e:
|
||||
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
@@ -553,8 +577,6 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
|
||||
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -571,3 +593,280 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
logger.info("[Transcript] Deleted metadata for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete metadata: %s", e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction — LLM summarization for prompt-too-long recovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# JSONL protocol values used in transcript serialization.
|
||||
STOP_REASON_END_TURN = "end_turn"
|
||||
COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string.
|
||||
|
||||
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
|
||||
placeholders. This is intentional: ``compress_context`` requires plain
|
||||
text for token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript was
|
||||
already too large for the model — a summarized plain-text version is
|
||||
better than no context at all.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
btype = block.get("type", "")
|
||||
if btype == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif btype == "tool_use":
|
||||
parts.append(f"[tool_use: {block.get('name', '?')}]")
|
||||
else:
|
||||
# Preserve non-text blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _flatten_tool_result_content(blocks: list) -> str:
|
||||
"""Flatten tool_result and other content blocks into plain text.
|
||||
|
||||
Handles nested tool_result structures, text blocks, and raw strings.
|
||||
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
|
||||
or where ``text`` is ``None``.
|
||||
|
||||
Like ``_flatten_assistant_content``, structured blocks (images, nested
|
||||
tool results) are reduced to text representations for compression.
|
||||
"""
|
||||
str_parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
inner = block.get("content") or ""
|
||||
if isinstance(inner, list):
|
||||
for sub in inner:
|
||||
if isinstance(sub, dict):
|
||||
sub_type = sub.get("type")
|
||||
if sub_type in ("image", "document"):
|
||||
# Avoid serializing base64 binary data into
|
||||
# the compaction input — use a placeholder.
|
||||
str_parts.append(f"[__{sub_type}__]")
|
||||
elif sub_type == "text" or sub.get("text") is not None:
|
||||
str_parts.append(str(sub.get("text", "")))
|
||||
else:
|
||||
str_parts.append(json.dumps(sub))
|
||||
else:
|
||||
str_parts.append(str(sub))
|
||||
else:
|
||||
str_parts.append(str(inner))
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
str_parts.append(str(block.get("text", "")))
|
||||
elif isinstance(block, dict):
|
||||
# Preserve non-text/non-tool_result blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
btype = block.get("type", "unknown")
|
||||
str_parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
str_parts.append(block)
|
||||
return "\n".join(str_parts) if str_parts else ""
|
||||
|
||||
|
||||
def _transcript_to_messages(content: str) -> list[dict]:
|
||||
"""Convert JSONL transcript entries to plain message dicts for compression.
|
||||
|
||||
Parses each line of the JSONL *content*, skips strippable metadata entries
|
||||
(progress, file-history-snapshot, etc.), and extracts the ``role`` and
|
||||
flattened ``content`` from the ``message`` field of each remaining entry.
|
||||
|
||||
Structured content blocks (``tool_use``, ``tool_result``, images) are
|
||||
flattened to plain text via ``_flatten_assistant_content`` and
|
||||
``_flatten_tool_result_content`` so that ``compress_context`` can
|
||||
perform token counting and LLM summarization on uniform strings.
|
||||
|
||||
Returns:
|
||||
A list of ``{"role": str, "content": str}`` dicts suitable for
|
||||
``compress_context``.
|
||||
"""
|
||||
messages: list[dict] = []
|
||||
for line in content.strip().split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
role = msg.get("role", "")
|
||||
if not role:
|
||||
continue
|
||||
msg_dict: dict = {"role": role}
|
||||
raw_content = msg.get("content")
|
||||
if role == "assistant" and isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_assistant_content(raw_content)
|
||||
elif isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_tool_result_content(raw_content)
|
||||
else:
|
||||
msg_dict["content"] = raw_content or ""
|
||||
messages.append(msg_dict)
|
||||
return messages
|
||||
|
||||
|
||||
def _messages_to_transcript(messages: list[dict]) -> str:
|
||||
"""Convert compressed message dicts back to JSONL transcript format.
|
||||
|
||||
Rebuilds a minimal JSONL transcript from the ``{"role", "content"}``
|
||||
dicts returned by ``compress_context``. Each message becomes one JSONL
|
||||
line with a fresh ``uuid`` / ``parentUuid`` chain so the CLI's
|
||||
``--resume`` flag can reconstruct a valid conversation tree.
|
||||
|
||||
Assistant messages are wrapped in the full ``message`` envelope
|
||||
(``id``, ``model``, ``stop_reason``, structured ``content`` blocks)
|
||||
that the CLI expects. User messages use the simpler ``{role, content}``
|
||||
form.
|
||||
|
||||
Returns:
|
||||
A newline-terminated JSONL string, or an empty string if *messages*
|
||||
is empty.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str = "" # root entry uses empty string, not null
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
uid = str(uuid4())
|
||||
content = msg.get("content", "")
|
||||
if role == "assistant":
|
||||
message: dict = {
|
||||
"role": "assistant",
|
||||
"model": "",
|
||||
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
|
||||
"type": ENTRY_TYPE_MESSAGE,
|
||||
"content": [{"type": "text", "text": content}] if content else [],
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
message = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": message,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n" if lines else ""
|
||||
|
||||
|
||||
_COMPACTION_TIMEOUT_SECONDS = 60
|
||||
_TRUNCATION_TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
async def _run_compression(
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
log_prefix: str,
|
||||
) -> CompressResult:
|
||||
"""Run LLM-based compression with truncation fallback.
|
||||
|
||||
Uses the shared OpenAI client from ``get_openai_client()``.
|
||||
If no client is configured or the LLM call fails, falls back to
|
||||
truncation-based compression which drops older messages without
|
||||
summarization.
|
||||
|
||||
A 60-second timeout prevents a hung LLM call from blocking the
|
||||
retry path indefinitely. The truncation fallback also has a
|
||||
30-second timeout to guard against slow tokenization on very large
|
||||
transcripts.
|
||||
"""
|
||||
client = get_openai_client()
|
||||
if client is None:
|
||||
logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=client),
|
||||
timeout=_COMPACTION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
*,
|
||||
model: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str | None:
|
||||
"""Compact an oversized JSONL transcript using LLM summarization.
|
||||
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
|
||||
is flattened to plain text for compression. This matches the fidelity of
|
||||
the Plan C (DB compression) fallback path, where
|
||||
``_format_conversation_context`` similarly renders tool calls as
|
||||
``You called tool: name(args)`` and results as ``Tool result: ...``.
|
||||
Neither path preserves structured API content blocks — the compacted
|
||||
context serves as text history for the LLM, which creates proper
|
||||
structured tool calls going forward.
|
||||
|
||||
Images are per-turn attachments loaded from workspace storage by file ID
|
||||
(via ``_prepare_file_attachments``), not part of the conversation history.
|
||||
They are re-attached each turn and are unaffected by compaction.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
|
||||
See also:
|
||||
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
|
||||
lists for pre-query DB history. Both share ``compress_context()``
|
||||
but operate on different input formats (JSONL transcript entries
|
||||
here vs. ChatMessage dicts there).
|
||||
"""
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, model, log_prefix)
|
||||
if not result.was_compacted:
|
||||
# Compressor says it's within budget, but the SDK rejected it.
|
||||
# Return None so the caller falls through to DB fallback.
|
||||
logger.warning(
|
||||
"%s Compressor reports within budget but SDK rejected — "
|
||||
"signalling failure",
|
||||
log_prefix,
|
||||
)
|
||||
return None
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
return compacted
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -68,7 +68,7 @@ class TranscriptBuilder:
|
||||
type=entry_type,
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
isCompactSummary=data.get("isCompactSummary") or None,
|
||||
isCompactSummary=data.get("isCompactSummary"),
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -9,9 +10,7 @@ from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
_cli_project_dir,
|
||||
delete_transcript,
|
||||
read_cli_session_file,
|
||||
read_compacted_entries,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
@@ -292,85 +291,6 @@ class TestStripProgressEntries:
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
|
||||
# --- read_cli_session_file ---
|
||||
|
||||
|
||||
class TestReadCliSessionFile:
|
||||
def test_no_matching_files_returns_none(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file returns None when no .jsonl files exist."""
|
||||
# Create a project dir with no jsonl files
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
assert read_cli_session_file("/fake/cwd") is None
|
||||
|
||||
def test_one_jsonl_file_returns_content(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file returns the content of a single .jsonl file."""
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
jsonl_file = project_dir / "session.jsonl"
|
||||
jsonl_file.write_text("line1\nline2\n")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
result = read_cli_session_file("/fake/cwd")
|
||||
assert result == "line1\nline2\n"
|
||||
|
||||
def test_symlink_escaping_project_dir_is_skipped(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file skips symlinks that escape the project dir."""
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
|
||||
# Create a file outside the project dir
|
||||
outside = tmp_path / "outside"
|
||||
outside.mkdir()
|
||||
outside_file = outside / "evil.jsonl"
|
||||
outside_file.write_text("should not be read\n")
|
||||
|
||||
# Symlink from inside project_dir to outside file
|
||||
symlink = project_dir / "evil.jsonl"
|
||||
symlink.symlink_to(outside_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
# The symlink target resolves outside project_dir, so it should be skipped
|
||||
result = read_cli_session_file("/fake/cwd")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- _cli_project_dir ---
|
||||
|
||||
|
||||
class TestCliProjectDir:
|
||||
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
|
||||
"""_cli_project_dir returns None when the project dir symlink escapes projects base."""
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir()
|
||||
projects_dir = config_dir / "projects"
|
||||
projects_dir.mkdir()
|
||||
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# Create a symlink inside projects/ that points outside of it.
|
||||
# _cli_project_dir encodes the cwd as all-alnum-hyphens, so use a
|
||||
# cwd whose encoded form matches the symlink name we create.
|
||||
evil_target = tmp_path / "escaped"
|
||||
evil_target.mkdir()
|
||||
|
||||
# The encoded form of "/evil/cwd" is "-evil-cwd"
|
||||
symlink_path = projects_dir / "-evil-cwd"
|
||||
symlink_path.symlink_to(evil_target)
|
||||
|
||||
result = _cli_project_dir("/evil/cwd")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- delete_transcript ---
|
||||
|
||||
|
||||
@@ -382,7 +302,7 @@ class TestDeleteTranscript:
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -402,7 +322,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -420,7 +340,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -897,3 +817,386 @@ class TestCompactionFlowIntegration:
|
||||
output2 = builder2.to_jsonl()
|
||||
lines2 = [json.loads(line) for line in output2.strip().split("\n")]
|
||||
assert lines2[-1]["parentUuid"] == "a2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_compression (direct tests for the 3 code paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunCompression:
|
||||
"""Direct tests for ``_run_compression`` covering all 3 code paths.
|
||||
|
||||
Paths:
|
||||
(a) No OpenAI client configured → truncation fallback immediately.
|
||||
(b) LLM success → returns LLM-compressed result.
|
||||
(c) LLM call raises → truncation fallback.
|
||||
"""
|
||||
|
||||
def _make_compress_result(self, was_compacted: bool, msgs=None):
|
||||
"""Build a minimal CompressResult-like object."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
return SimpleNamespace(
|
||||
was_compacted=was_compacted,
|
||||
messages=msgs or [{"role": "user", "content": "summary"}],
|
||||
original_token_count=500,
|
||||
token_count=100 if was_compacted else 500,
|
||||
messages_summarized=2 if was_compacted else 0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation(self):
|
||||
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated"}]
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "hello"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called with client=None (truncation mode)
|
||||
call_kwargs = mock_compress.call_args
|
||||
assert (
|
||||
call_kwargs.kwargs.get("client") is None
|
||||
or (call_kwargs.args and call_kwargs.args[2] is None)
|
||||
or mock_compress.call_args[1].get("client") is None
|
||||
)
|
||||
assert result is truncation_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_success_returns_llm_result(self):
|
||||
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
llm_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "LLM summary"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=llm_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called with the real client
|
||||
assert mock_compress.called
|
||||
assert result is llm_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_failure_falls_back_to_truncation(self):
|
||||
"""Path (c): LLM call raises → truncation fallback used instead."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated fallback"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
call_count = [0]
|
||||
|
||||
async def _compress_side_effect(**kwargs):
|
||||
call_count[0] += 1
|
||||
if kwargs.get("client") is not None:
|
||||
raise RuntimeError("LLM timeout")
|
||||
return truncation_result
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called twice: once for LLM (raises), once for truncation
|
||||
assert call_count[0] == 2
|
||||
assert result is truncation_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_timeout_falls_back_to_truncation(self):
|
||||
"""Path (d): LLM call exceeds timeout → truncation fallback used."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated after timeout"}]
|
||||
)
|
||||
call_count = [0]
|
||||
|
||||
async def _compress_side_effect(*, messages, model, client):
|
||||
call_count[0] += 1
|
||||
if client is not None:
|
||||
# Simulate a hang that exceeds the timeout
|
||||
await asyncio.sleep(9999)
|
||||
return truncation_result
|
||||
|
||||
fake_client = MagicMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=fake_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._COMPACTION_TIMEOUT_SECONDS",
|
||||
0.05,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._TRUNCATION_TIMEOUT_SECONDS",
|
||||
5,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called twice: once for LLM (times out), once truncation
|
||||
assert call_count[0] == 2
|
||||
assert result is truncation_result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cleanup_stale_project_dirs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupStaleProjectDirs:
|
||||
"""Tests for cleanup_stale_project_dirs (disk leak prevention)."""
|
||||
|
||||
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories matching copilot pattern older than threshold are removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
# Create a stale dir
|
||||
stale = projects_dir / "-tmp-copilot-old-session"
|
||||
stale.mkdir()
|
||||
# Set mtime to past the threshold
|
||||
import time
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
os.utime(stale, (old_time, old_time))
|
||||
|
||||
# Create a fresh dir
|
||||
fresh = projects_dir / "-tmp-copilot-new-session"
|
||||
fresh.mkdir()
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 1
|
||||
assert not stale.exists()
|
||||
assert fresh.exists()
|
||||
|
||||
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories not matching copilot pattern are left alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
# Non-copilot dir that's old
|
||||
import time
|
||||
|
||||
other = projects_dir / "some-other-project"
|
||||
other.mkdir()
|
||||
old_time = time.time() - 999999
|
||||
os.utime(other, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert other.exists()
|
||||
|
||||
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
|
||||
"""A directory exactly at the TTL boundary should NOT be removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
# Dir that's exactly at the TTL (age == threshold, not >) — should survive
|
||||
boundary = projects_dir / "-tmp-copilot-boundary"
|
||||
boundary.mkdir()
|
||||
boundary_time = time.time() - _STALE_PROJECT_DIR_SECONDS + 1
|
||||
os.utime(boundary, (boundary_time, boundary_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert boundary.exists()
|
||||
|
||||
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
|
||||
"""Regular files matching the copilot pattern are not removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
# Create a regular FILE (not a dir) with the copilot pattern name
|
||||
stale_file = projects_dir / "-tmp-copilot-stale-file"
|
||||
stale_file.write_text("not a dir")
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
os.utime(stale_file, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert stale_file.exists()
|
||||
|
||||
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
|
||||
"""If the projects base directory doesn't exist, return 0 gracefully."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
nonexistent = str(tmp_path / "does-not-exist" / "projects")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: nonexistent,
|
||||
)
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
|
||||
def test_scoped_removes_only_target_dir(self, tmp_path, monkeypatch):
|
||||
"""When encoded_cwd is supplied only that directory is swept."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
|
||||
# Two stale copilot dirs
|
||||
target = projects_dir / "-tmp-copilot-session-abc"
|
||||
target.mkdir()
|
||||
os.utime(target, (old_time, old_time))
|
||||
|
||||
other = projects_dir / "-tmp-copilot-session-xyz"
|
||||
other.mkdir()
|
||||
os.utime(other, (old_time, old_time))
|
||||
|
||||
# Only the target dir should be removed
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-abc")
|
||||
assert removed == 1
|
||||
assert not target.exists()
|
||||
assert other.exists() # untouched — not the current session
|
||||
|
||||
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep leaves a fresh directory alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
fresh = projects_dir / "-tmp-copilot-session-new"
|
||||
fresh.mkdir()
|
||||
# mtime is now — well within TTL
|
||||
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-new")
|
||||
assert removed == 0
|
||||
assert fresh.exists()
|
||||
|
||||
def test_scoped_non_copilot_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep refuses to remove a non-copilot directory."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
non_copilot = projects_dir / "some-other-project"
|
||||
non_copilot.mkdir()
|
||||
os.utime(non_copilot, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
|
||||
assert removed == 0
|
||||
assert non_copilot.exists()
|
||||
|
||||
@@ -17,11 +17,13 @@ Subscribers:
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.api.model import CopilotCompletionPayload
|
||||
from backend.data.notification_bus import (
|
||||
@@ -33,12 +35,21 @@ from backend.data.redis_client import get_redis_async
|
||||
from .config import ChatConfig
|
||||
from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS
|
||||
from .response_model import (
|
||||
ResponseType,
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -280,6 +291,56 @@ async def publish_chunk(
|
||||
return message_id
|
||||
|
||||
|
||||
async def stream_and_publish(
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
stream: AsyncIterator[StreamBaseResponse],
|
||||
) -> AsyncIterator[StreamBaseResponse]:
|
||||
"""Wrap an async stream iterator with registry publishing.
|
||||
|
||||
Publishes each chunk to the stream registry for frontend SSE consumption,
|
||||
skipping ``StreamFinish`` and ``StreamError`` (which are published by
|
||||
:func:`mark_session_completed`).
|
||||
|
||||
This is a pass-through: every event from *stream* is yielded unchanged so
|
||||
the caller can still consume and aggregate them. The caller is responsible
|
||||
for calling :func:`create_session` before and :func:`mark_session_completed`
|
||||
after iterating.
|
||||
|
||||
Args:
|
||||
session_id: Chat session ID (for logging only).
|
||||
turn_id: Turn UUID that identifies the Redis stream to publish to.
|
||||
If empty, publishing is silently skipped (graceful degradation).
|
||||
stream: The underlying async iterator of stream events.
|
||||
|
||||
Yields:
|
||||
Every event from *stream*, unchanged.
|
||||
"""
|
||||
publish_failed_once = False
|
||||
|
||||
async for event in stream:
|
||||
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
|
||||
try:
|
||||
await publish_chunk(turn_id, event)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
if not publish_failed_once:
|
||||
publish_failed_once = True
|
||||
logger.warning(
|
||||
"[stream_and_publish] Failed to publish chunk %s for %s "
|
||||
"(further failures logged at DEBUG)",
|
||||
type(event).__name__,
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"[stream_and_publish] Failed to publish chunk %s",
|
||||
type(event).__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
yield event
|
||||
|
||||
|
||||
async def subscribe_to_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
@@ -693,6 +754,8 @@ async def _stream_listener(
|
||||
async def mark_session_completed(
|
||||
session_id: str,
|
||||
error_message: str | None = None,
|
||||
*,
|
||||
skip_error_publish: bool = False,
|
||||
) -> bool:
|
||||
"""Mark a session as completed, then publish StreamFinish.
|
||||
|
||||
@@ -708,6 +771,10 @@ async def mark_session_completed(
|
||||
session_id: Session ID to mark as completed
|
||||
error_message: If provided, marks as "failed" and publishes a
|
||||
StreamError before StreamFinish. Otherwise marks as "completed".
|
||||
skip_error_publish: If True, still marks the session as "failed" but
|
||||
does NOT publish a StreamError event. Use this when the error has
|
||||
already been published to the stream (e.g. via stream_and_publish)
|
||||
to avoid duplicate error delivery to the frontend.
|
||||
|
||||
Returns:
|
||||
True if session was newly marked completed, False if already completed/failed
|
||||
@@ -727,7 +794,7 @@ async def mark_session_completed(
|
||||
logger.debug(f"Session {session_id} already completed/failed, skipping")
|
||||
return False
|
||||
|
||||
if error_message:
|
||||
if error_message and not skip_error_publish:
|
||||
try:
|
||||
await publish_chunk(turn_id, StreamError(errorText=error_message))
|
||||
except Exception as e:
|
||||
@@ -913,21 +980,6 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
Returns:
|
||||
Reconstructed response object, or None if unknown type
|
||||
"""
|
||||
from .response_model import (
|
||||
ResponseType,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextEnd,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
# Map response types to their corresponding classes
|
||||
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
||||
ResponseType.START.value: StreamStart,
|
||||
|
||||
@@ -4,11 +4,12 @@ These tests verify the complete copilot flow using dummy implementations
|
||||
for agent generator and SDK service, allowing automated testing without
|
||||
external LLM calls.
|
||||
|
||||
Enable test mode with COPILOT_TEST_MODE=true environment variable.
|
||||
Enable test mode with CHAT_TEST_MODE=true environment variable (or in .env).
|
||||
|
||||
Note: StreamFinish is NOT emitted by the dummy service — it is published
|
||||
by mark_session_completed in the processor layer. These tests only cover
|
||||
the service-level streaming output (StreamStart + StreamTextDelta).
|
||||
The dummy service emits the full AI SDK protocol event sequence:
|
||||
StreamStart → StreamStartStep → StreamTextStart → StreamTextDelta(s) →
|
||||
StreamTextEnd → StreamFinishStep → StreamFinish.
|
||||
The processor skips StreamFinish and publishes its own via mark_session_completed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -20,9 +21,14 @@ import pytest
|
||||
from backend.copilot.model import ChatMessage, ChatSession, upsert_chat_session
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
)
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
|
||||
@@ -30,9 +36,9 @@ from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_test_mode():
|
||||
"""Enable test mode for all tests in this module."""
|
||||
os.environ["COPILOT_TEST_MODE"] = "true"
|
||||
os.environ["CHAT_TEST_MODE"] = "true"
|
||||
yield
|
||||
os.environ.pop("COPILOT_TEST_MODE", None)
|
||||
os.environ.pop("CHAT_TEST_MODE", None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -110,9 +116,14 @@ async def test_streaming_event_types():
|
||||
):
|
||||
event_types.add(type(event).__name__)
|
||||
|
||||
# Required event types (StreamFinish is published by processor, not service)
|
||||
# Required event types for full AI SDK protocol
|
||||
assert "StreamStart" in event_types, "Missing StreamStart"
|
||||
assert "StreamStartStep" in event_types, "Missing StreamStartStep"
|
||||
assert "StreamTextStart" in event_types, "Missing StreamTextStart"
|
||||
assert "StreamTextDelta" in event_types, "Missing StreamTextDelta"
|
||||
assert "StreamTextEnd" in event_types, "Missing StreamTextEnd"
|
||||
assert "StreamFinishStep" in event_types, "Missing StreamFinishStep"
|
||||
assert "StreamFinish" in event_types, "Missing StreamFinish"
|
||||
|
||||
print(f"✅ Event types: {sorted(event_types)}")
|
||||
|
||||
@@ -175,16 +186,17 @@ async def test_streaming_heartbeat_timing():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling():
|
||||
"""Test that errors are properly formatted and sent."""
|
||||
# This would require a dummy that can trigger errors
|
||||
# For now, just verify error event structure
|
||||
|
||||
"""Test that error events have correct SSE structure."""
|
||||
error = StreamError(errorText="Test error", code="test_error")
|
||||
assert error.errorText == "Test error"
|
||||
assert error.code == "test_error"
|
||||
assert str(error.type.value) in ["error", "error"]
|
||||
|
||||
print("✅ Error structure verified")
|
||||
# Verify to_sse() strips code (AI SDK protocol compliance)
|
||||
sse = error.to_sse()
|
||||
assert '"errorText"' in sse
|
||||
assert '"code"' not in sse, "to_sse() must strip code field for AI SDK"
|
||||
|
||||
print("✅ Error structure verified (code stripped in SSE)")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -326,20 +338,85 @@ async def test_stream_completeness():
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Check for required events (StreamFinish is published by processor)
|
||||
has_start = any(isinstance(e, StreamStart) for e in events)
|
||||
has_text = any(isinstance(e, StreamTextDelta) for e in events)
|
||||
|
||||
assert has_start, "Stream must include StreamStart"
|
||||
assert has_text, "Stream must include text deltas"
|
||||
# Check for all required event types
|
||||
assert any(isinstance(e, StreamStart) for e in events), "Missing StreamStart"
|
||||
assert any(
|
||||
isinstance(e, StreamStartStep) for e in events
|
||||
), "Missing StreamStartStep"
|
||||
assert any(
|
||||
isinstance(e, StreamTextStart) for e in events
|
||||
), "Missing StreamTextStart"
|
||||
assert any(
|
||||
isinstance(e, StreamTextDelta) for e in events
|
||||
), "Missing StreamTextDelta"
|
||||
assert any(isinstance(e, StreamTextEnd) for e in events), "Missing StreamTextEnd"
|
||||
assert any(
|
||||
isinstance(e, StreamFinishStep) for e in events
|
||||
), "Missing StreamFinishStep"
|
||||
assert any(isinstance(e, StreamFinish) for e in events), "Missing StreamFinish"
|
||||
|
||||
# Verify exactly one start
|
||||
start_count = sum(1 for e in events if isinstance(e, StreamStart))
|
||||
assert start_count == 1, f"Should have exactly 1 StreamStart, got {start_count}"
|
||||
|
||||
print(
|
||||
f"✅ Completeness: 1 start, {sum(1 for e in events if isinstance(e, StreamTextDelta))} text deltas"
|
||||
)
|
||||
print(f"✅ Completeness: {len(events)} events, full protocol sequence")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transient_error_shows_retryable():
|
||||
"""Test __test_transient_error__ yields partial text then retryable StreamError."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-transient",
|
||||
message="please fail __test_transient_error__",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Should start with StreamStart
|
||||
assert isinstance(events[0], StreamStart)
|
||||
|
||||
# Should have some partial text before the error
|
||||
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
|
||||
assert len(text_events) > 0, "Should stream partial text before error"
|
||||
|
||||
# Should end with StreamError
|
||||
error_events = [e for e in events if isinstance(e, StreamError)]
|
||||
assert len(error_events) == 1, "Should have exactly one StreamError"
|
||||
assert error_events[0].code == "transient_api_error"
|
||||
assert "connection interrupted" in error_events[0].errorText.lower()
|
||||
|
||||
print(f"✅ Transient error: {len(text_events)} partial deltas + retryable error")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fatal_error_not_retryable():
|
||||
"""Test __test_fatal_error__ yields StreamError without retryable code."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-fatal",
|
||||
message="__test_fatal_error__",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert isinstance(events[0], StreamStart)
|
||||
|
||||
# Should have StreamError with sdk_error code (not transient)
|
||||
error_events = [e for e in events if isinstance(e, StreamError)]
|
||||
assert len(error_events) == 1
|
||||
assert error_events[0].code == "sdk_error"
|
||||
assert "transient" not in error_events[0].code
|
||||
|
||||
# Should NOT have any text deltas (fatal errors fail immediately)
|
||||
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
|
||||
assert len(text_events) == 0, "Fatal error should not stream any text"
|
||||
|
||||
print("✅ Fatal error: immediate error, no partial text")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -395,6 +472,8 @@ if __name__ == "__main__":
|
||||
asyncio.run(test_message_deduplication())
|
||||
asyncio.run(test_event_ordering())
|
||||
asyncio.run(test_stream_completeness())
|
||||
asyncio.run(test_transient_error_shows_retryable())
|
||||
asyncio.run(test_fatal_error_not_retryable())
|
||||
asyncio.run(test_text_delta_consistency())
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
@@ -84,6 +85,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
"connect_integration": ConnectIntegrationTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user