mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
56 Commits
feat/llm-a
...
feat/agent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3db2a944f7 | ||
|
|
59192102a6 | ||
|
|
65cca9bef8 | ||
|
|
6b32e43d84 | ||
|
|
b73d05c23e | ||
|
|
8277cce835 | ||
|
|
57b17dc8e1 | ||
|
|
a20188ae59 | ||
|
|
c410be890e | ||
|
|
37d9863552 | ||
|
|
2f42ff9b47 | ||
|
|
914efc53e5 | ||
|
|
17e78ca382 | ||
|
|
7ba05366ed | ||
|
|
ca74f980c1 | ||
|
|
68f5d2ad08 | ||
|
|
2b3d730ca9 | ||
|
|
f28628e34b | ||
|
|
b6a027fd2b | ||
|
|
fb74fcf4a4 | ||
|
|
28b26dde94 | ||
|
|
d677978c90 | ||
|
|
a347c274b7 | ||
|
|
f79d8f0449 | ||
|
|
1bc48c55d5 | ||
|
|
9d0a31c0f1 | ||
|
|
9b086e39c6 | ||
|
|
5867e4d613 | ||
|
|
85f0d8353a | ||
|
|
f871717f68 | ||
|
|
f08e52dc86 | ||
|
|
500b345b3b | ||
|
|
995dd1b5f3 | ||
|
|
336114f217 | ||
|
|
866563ad25 | ||
|
|
e79928a815 | ||
|
|
1771ed3bef | ||
|
|
550fa5a319 | ||
|
|
8528dffbf2 | ||
|
|
8fbf6a4b09 | ||
|
|
239148596c | ||
|
|
a880d73481 | ||
|
|
80bfd64ffa | ||
|
|
0076ad2a1a | ||
|
|
edb3d322f0 | ||
|
|
9381057079 | ||
|
|
f21a36ca37 | ||
|
|
ee5382a064 | ||
|
|
b80e5ea987 | ||
|
|
3d4fcfacb6 | ||
|
|
32eac6d52e | ||
|
|
9762f4cde7 | ||
|
|
76901ba22f | ||
|
|
23b65939f3 | ||
|
|
1c27eaac53 | ||
|
|
923b164794 |
106
.claude/skills/open-pr/SKILL.md
Normal file
106
.claude/skills/open-pr/SKILL.md
Normal file
@@ -0,0 +1,106 @@
|
||||
---
|
||||
name: open-pr
|
||||
description: Open a pull request with proper PR template, test coverage, and review workflow. Guides agents through creating a PR that follows repo conventions, ensures existing behaviors aren't broken, covers new behaviors with tests, and handles review via bot when local testing isn't possible. TRIGGER when user asks to "open a PR", "create a PR", "make a PR", "submit a PR", "open pull request", "push and create PR", or any variation of opening/submitting a pull request.
|
||||
user-invocable: true
|
||||
args: "[base-branch] — optional target branch (defaults to dev)."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Open a Pull Request
|
||||
|
||||
## Step 1: Pre-flight checks
|
||||
|
||||
Before opening the PR:
|
||||
|
||||
1. Ensure all changes are committed
|
||||
2. Ensure the branch is pushed to the remote (`git push -u origin <branch>`)
|
||||
3. Run linters/formatters across the whole repo (not just changed files) and commit any fixes
|
||||
|
||||
## Step 2: Test coverage
|
||||
|
||||
**This is critical.** Before opening the PR, verify:
|
||||
|
||||
### Existing behavior is not broken
|
||||
- Identify which modules/components your changes touch
|
||||
- Run the existing test suites for those areas
|
||||
- If tests fail, fix them before opening the PR — do not open a PR with known regressions
|
||||
|
||||
### New behavior has test coverage
|
||||
- Every new feature, endpoint, or behavior change needs tests
|
||||
- If you added a new block, add tests for that block
|
||||
- If you changed API behavior, add or update API tests
|
||||
- If you changed frontend behavior, verify it doesn't break existing flows
|
||||
|
||||
If you cannot run the full test suite locally, note which tests you ran and which you couldn't in the test plan.
|
||||
|
||||
## Step 3: Create the PR using the repo template
|
||||
|
||||
Read the canonical PR template at `.github/PULL_REQUEST_TEMPLATE.md` and use it **verbatim** as your PR body:
|
||||
|
||||
1. Read the template: `cat .github/PULL_REQUEST_TEMPLATE.md`
|
||||
2. Preserve the exact section titles and formatting, including:
|
||||
- `### Why / What / How`
|
||||
- `### Changes 🏗️`
|
||||
- `### Checklist 📋`
|
||||
3. Replace HTML comment prompts (`<!-- ... -->`) with actual content; do not leave them in
|
||||
4. **Do not pre-check boxes** — leave all checkboxes as `- [ ]` until each step is actually completed
|
||||
5. Do not alter the template structure, rename sections, or remove any checklist items
|
||||
|
||||
**PR title must use conventional commit format** (e.g., `feat(backend): add new block`, `fix(frontend): resolve routing bug`, `dx(skills): update PR workflow`). See CLAUDE.md for the full list of scopes.
|
||||
|
||||
Use `gh pr create` with the base branch (defaults to `dev` if no `[base-branch]` was provided). Use `--body-file` to avoid shell interpretation of backticks and special characters:
|
||||
|
||||
```bash
|
||||
BASE_BRANCH="${BASE_BRANCH:-dev}"
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
<filled-in template from .github/PULL_REQUEST_TEMPLATE.md>
|
||||
PREOF
|
||||
gh pr create --base "$BASE_BRANCH" --title "<type>(scope): short description" --body-file "$PR_BODY"
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
|
||||
## Step 4: Review workflow
|
||||
|
||||
### If you have a workspace that allows testing (docker, running backend, etc.)
|
||||
- Run `/pr-test` to do E2E manual testing of the PR using docker compose, agent-browser, and API calls. This is the most thorough way to validate your changes before review.
|
||||
- After testing, run `/pr-review` to self-review the PR for correctness, security, code quality, and testing gaps before requesting human review.
|
||||
|
||||
### If you do NOT have a workspace that allows testing
|
||||
This is common for agents running in worktrees without a full stack. In this case:
|
||||
|
||||
1. Run `/pr-review` locally to catch obvious issues before pushing
|
||||
2. **Comment `/review` on the PR** after creating it to trigger the review bot
|
||||
3. **Poll for the review** rather than blindly waiting — check for new review comments every 30 seconds using `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` and the GraphQL inline threads query. The bot typically responds within 30 minutes, but polling lets the agent react as soon as it arrives.
|
||||
4. Do NOT proceed or merge until the bot review comes back
|
||||
5. Address any issues the bot raises — use `/pr-address` which has a full polling loop with CI + comment tracking
|
||||
|
||||
```bash
|
||||
# After creating the PR:
|
||||
PR_NUMBER=$(gh pr view --json number -q .number)
|
||||
gh pr comment "$PR_NUMBER" --body "/review"
|
||||
# Then use /pr-address to poll for and address the review when it arrives
|
||||
```
|
||||
|
||||
## Step 5: Address review feedback
|
||||
|
||||
Once the review bot or human reviewers leave comments:
|
||||
- Run `/pr-address` to address review comments. It will loop until CI is green and all comments are resolved.
|
||||
- Do not merge without human approval.
|
||||
|
||||
## Related skills
|
||||
|
||||
| Skill | When to use |
|
||||
|---|---|
|
||||
| `/pr-test` | E2E testing with docker compose, agent-browser, API calls — use when you have a running workspace |
|
||||
| `/pr-review` | Review for correctness, security, code quality — use before requesting human review |
|
||||
| `/pr-address` | Address reviewer comments and loop until CI green — use after reviews come in |
|
||||
|
||||
## Step 6: Post-creation
|
||||
|
||||
After the PR is created and review is triggered:
|
||||
- Share the PR URL with the user
|
||||
- If waiting on the review bot, let the user know the expected wait time (~30 min)
|
||||
- Do not merge without human approval
|
||||
@@ -17,6 +17,14 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Read the PR description
|
||||
|
||||
Understand the **Why / What / How** before addressing comments — you need context to make good fixes:
|
||||
|
||||
```bash
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
## Fetch comments (all sources)
|
||||
|
||||
### 1. Inline review threads — GraphQL (primary source of actionable items)
|
||||
@@ -105,7 +113,9 @@ kill $REST_PID 2>/dev/null; trap - EXIT
|
||||
```
|
||||
Never manually edit files in `src/app/api/__generated__/`.
|
||||
|
||||
Then commit and **push immediately** — never batch commits without pushing.
|
||||
Then commit and **push immediately** — never batch commits without pushing. Each fix should be visible on GitHub right away so CI can start and reviewers can see progress.
|
||||
|
||||
**Never push empty commits** (`git commit --allow-empty`) to re-trigger CI or bot checks. When a check fails, investigate the root cause (unchecked PR checklist, unaddressed review comments, code issues) and fix those directly. Empty commits add noise to git history.
|
||||
|
||||
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
|
||||
@@ -17,6 +17,16 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Read the PR description
|
||||
|
||||
Before reading code, understand the **why**, **what**, and **how** from the PR description:
|
||||
|
||||
```bash
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
Every PR should have a Why / What / How structure. If any of these are missing, note it as feedback.
|
||||
|
||||
## Read the diff
|
||||
|
||||
```bash
|
||||
@@ -34,6 +44,8 @@ gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
|
||||
|
||||
## What to check
|
||||
|
||||
**Description quality:** Does the PR description cover Why (motivation/problem), What (summary of changes), and How (approach/implementation details)? If any are missing, request them — you can't judge the approach without understanding the problem and intent.
|
||||
|
||||
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
|
||||
|
||||
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
|
||||
|
||||
@@ -5,13 +5,96 @@ user-invocable: true
|
||||
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
version: "2.0.0"
|
||||
---
|
||||
|
||||
# Manual E2E Test
|
||||
|
||||
Test a PR/branch end-to-end by building the full platform, interacting via browser and API, capturing screenshots, and reporting results.
|
||||
|
||||
## Critical Requirements
|
||||
|
||||
These are NON-NEGOTIABLE. Every test run MUST satisfy ALL the following:
|
||||
|
||||
### 1. Screenshots at Every Step
|
||||
- Take a screenshot at EVERY significant test step — not just at the end
|
||||
- Every test scenario MUST have at least one BEFORE and one AFTER screenshot
|
||||
- Name screenshots sequentially: `{NN}-{action}-{state}.png` (e.g., `01-credits-before.png`, `02-credits-after.png`)
|
||||
- If a screenshot is missing for a scenario, the test is INCOMPLETE — go back and take it
|
||||
|
||||
### 2. Screenshots MUST Be Posted to PR
|
||||
- Push ALL screenshots to a temp branch `test-screenshots/pr-{N}`
|
||||
- Post a PR comment with ALL screenshots embedded inline using GitHub raw URLs
|
||||
- This is NOT optional — every test run MUST end with a PR comment containing screenshots
|
||||
- If screenshot upload fails, retry. If it still fails, list failed files and require manual drag-and-drop/paste attachment in the PR comment
|
||||
|
||||
### 3. State Verification with Before/After Evidence
|
||||
- For EVERY state-changing operation (API call, user action), capture the state BEFORE and AFTER
|
||||
- Log the actual API response values (e.g., `credits_before=100, credits_after=95`)
|
||||
- Screenshot MUST show the relevant UI state change
|
||||
- Compare expected vs actual values explicitly — do not just eyeball it
|
||||
|
||||
### 4. Negative Test Cases Are Mandatory
|
||||
- Test at least ONE negative case per feature (e.g., insufficient credits, invalid input, unauthorized access)
|
||||
- Verify error messages are user-friendly and accurate
|
||||
- Verify the system state did NOT change after a rejected operation
|
||||
|
||||
### 5. Test Report Must Include Full Evidence
|
||||
Each test scenario in the report MUST have:
|
||||
- **Steps**: What was done (exact commands or UI actions)
|
||||
- **Expected**: What should happen
|
||||
- **Actual**: What actually happened
|
||||
- **API Evidence**: Before/after API response values for state-changing operations
|
||||
- **Screenshot Evidence**: Before/after screenshots with explanations
|
||||
|
||||
## State Manipulation for Realistic Testing
|
||||
|
||||
When testing features that depend on specific states (rate limits, credits, quotas):
|
||||
|
||||
1. **Use Redis CLI to set counters directly:**
|
||||
```bash
|
||||
# Find the Redis container
|
||||
REDIS_CONTAINER=$(docker ps --format '{{.Names}}' | grep redis | head -1)
|
||||
# Set a key with expiry
|
||||
docker exec $REDIS_CONTAINER redis-cli SET key value EX ttl
|
||||
# Example: Set rate limit counter to near-limit
|
||||
docker exec $REDIS_CONTAINER redis-cli SET "rate_limit:user:test@test.com" 99 EX 3600
|
||||
# Example: Check current value
|
||||
docker exec $REDIS_CONTAINER redis-cli GET "rate_limit:user:test@test.com"
|
||||
```
|
||||
|
||||
2. **Use API calls to check before/after state:**
|
||||
```bash
|
||||
# BEFORE: Record current state
|
||||
BEFORE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
echo "Credits BEFORE: $BEFORE"
|
||||
|
||||
# Perform the action...
|
||||
|
||||
# AFTER: Record new state and compare
|
||||
AFTER=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
echo "Credits AFTER: $AFTER"
|
||||
echo "Delta: $(( BEFORE - AFTER ))"
|
||||
```
|
||||
|
||||
3. **Take screenshots BEFORE and AFTER state changes** — the UI must reflect the backend state change
|
||||
|
||||
4. **Never rely on mocked/injected browser state** — always use real backend state. Do NOT use `agent-browser eval` to fake UI state. The backend must be the source of truth.
|
||||
|
||||
5. **Use direct DB queries when needed:**
|
||||
```bash
|
||||
# Query via Supabase's PostgREST or docker exec into the DB
|
||||
docker exec supabase-db psql -U supabase_admin -d postgres -c "SELECT credits FROM user_credits WHERE user_id = '...';"
|
||||
```
|
||||
|
||||
6. **After every API test, verify the state change actually persisted:**
|
||||
```bash
|
||||
# Example: After a credits purchase, verify DB matches API
|
||||
API_CREDITS=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
DB_CREDITS=$(docker exec supabase-db psql -U supabase_admin -d postgres -t -c "SELECT credits FROM user_credits WHERE user_id = '...';" | tr -d ' ')
|
||||
[ "$API_CREDITS" = "$DB_CREDITS" ] && echo "CONSISTENT" || echo "MISMATCH: API=$API_CREDITS DB=$DB_CREDITS"
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
- `$ARGUMENTS` — worktree path (e.g. `$REPO_ROOT`) or PR number
|
||||
@@ -53,14 +136,20 @@ Before testing, understand what changed:
|
||||
|
||||
```bash
|
||||
cd $WORKTREE_PATH
|
||||
|
||||
# Read PR description to understand the WHY
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
|
||||
git log --oneline dev..HEAD | head -20
|
||||
git diff dev --stat
|
||||
```
|
||||
|
||||
Read the changed files to understand:
|
||||
1. What feature/fix does this PR implement?
|
||||
2. What components are affected? (backend, frontend, copilot, executor, etc.)
|
||||
3. What are the key user-facing behaviors to test?
|
||||
Read the PR description (Why / What / How) and changed files to understand:
|
||||
0. **Why** does this PR exist? What problem does it solve?
|
||||
1. **What** feature/fix does this PR implement?
|
||||
2. **How** does it work? What's the approach?
|
||||
3. What components are affected? (backend, frontend, copilot, executor, etc.)
|
||||
4. What are the key user-facing behaviors to test?
|
||||
|
||||
## Step 2: Write test scenarios
|
||||
|
||||
@@ -75,15 +164,21 @@ Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
|
||||
|
||||
## API Tests (if applicable)
|
||||
1. [Endpoint] — [expected behavior]
|
||||
- Before state: [what to check before]
|
||||
- After state: [what to verify changed]
|
||||
|
||||
## UI Tests (if applicable)
|
||||
1. [Page/component] — [interaction to test]
|
||||
- Screenshot before: [what to capture]
|
||||
- Screenshot after: [what to capture]
|
||||
|
||||
## Negative Tests
|
||||
1. [What should NOT happen]
|
||||
## Negative Tests (REQUIRED — at least one per feature)
|
||||
1. [What should NOT happen] — [how to trigger it]
|
||||
- Expected error: [what error message/code]
|
||||
- State unchanged: [what to verify did NOT change]
|
||||
```
|
||||
|
||||
**Be critical** — include edge cases, error paths, and security checks.
|
||||
**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.
|
||||
|
||||
## Step 3: Environment setup
|
||||
|
||||
@@ -233,7 +328,7 @@ curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
|
||||
|
||||
### API testing
|
||||
|
||||
Use `curl` with the auth token for backend API tests:
|
||||
Use `curl` with the auth token for backend API tests. **For EVERY API call that changes state, record before/after values:**
|
||||
|
||||
```bash
|
||||
# Example: List agents
|
||||
@@ -256,6 +351,27 @@ curl -s -H "Authorization: Bearer $TOKEN" \
|
||||
"http://localhost:8006/api/graphs/{graph_id}/executions/{exec_id}" | jq .
|
||||
```
|
||||
|
||||
**State verification pattern (use for EVERY state-changing API call):**
|
||||
```bash
|
||||
# 1. Record BEFORE state
|
||||
BEFORE_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
|
||||
echo "BEFORE: $BEFORE_STATE"
|
||||
|
||||
# 2. Perform the action
|
||||
ACTION_RESULT=$(curl -s -X POST ... | jq .)
|
||||
echo "ACTION RESULT: $ACTION_RESULT"
|
||||
|
||||
# 3. Record AFTER state
|
||||
AFTER_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
|
||||
echo "AFTER: $AFTER_STATE"
|
||||
|
||||
# 4. Log the comparison
|
||||
echo "=== STATE CHANGE VERIFICATION ==="
|
||||
echo "Before: $BEFORE_STATE"
|
||||
echo "After: $AFTER_STATE"
|
||||
echo "Expected change: {describe what should have changed}"
|
||||
```
|
||||
|
||||
### Browser testing with agent-browser
|
||||
|
||||
```bash
|
||||
@@ -346,59 +462,90 @@ agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --time
|
||||
# ... fill chat input and press Enter, wait 20-30s for response
|
||||
```
|
||||
|
||||
## Step 5: Record results
|
||||
## Step 5: Record results and take screenshots
|
||||
|
||||
For each test scenario, record in `$RESULTS_DIR/test-report.md`:
|
||||
**Take a screenshot at EVERY significant test step** — before and after interactions, on success, and on failure. This is NON-NEGOTIABLE.
|
||||
|
||||
**Required screenshot pattern for each test scenario:**
|
||||
```bash
|
||||
# BEFORE the action
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-before.png
|
||||
|
||||
# Perform the action...
|
||||
|
||||
# AFTER the action
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-after.png
|
||||
```
|
||||
|
||||
**Naming convention:**
|
||||
```bash
|
||||
# Examples:
|
||||
# $RESULTS_DIR/01-login-page-before.png
|
||||
# $RESULTS_DIR/02-login-page-after.png
|
||||
# $RESULTS_DIR/03-credits-page-before.png
|
||||
# $RESULTS_DIR/04-credits-purchase-after.png
|
||||
# $RESULTS_DIR/05-negative-insufficient-credits.png
|
||||
# $RESULTS_DIR/06-error-state.png
|
||||
```
|
||||
|
||||
**Minimum requirements:**
|
||||
- At least TWO screenshots per test scenario (before + after)
|
||||
- At least ONE screenshot for each negative test case showing the error state
|
||||
- If a test fails, screenshot the failure state AND any error logs visible in the UI
|
||||
|
||||
## Step 6: Show results to user with screenshots
|
||||
|
||||
**CRITICAL: After all tests complete, you MUST show every screenshot to the user using the Read tool, with an explanation of what each screenshot shows.** This is the most important part of the test report — the user needs to visually verify the results.
|
||||
|
||||
For each screenshot:
|
||||
1. Use the `Read` tool to display the PNG file (Claude can read images)
|
||||
2. Write a 1-2 sentence explanation below it describing:
|
||||
- What page/state is being shown
|
||||
- What the screenshot proves (which test scenario it validates)
|
||||
- Any notable details visible in the UI
|
||||
|
||||
Format the output like this:
|
||||
|
||||
```markdown
|
||||
# E2E Test Report: PR #{N} — {title}
|
||||
Date: {date}
|
||||
Branch: {branch}
|
||||
Worktree: {path}
|
||||
### Screenshot 1: {descriptive title}
|
||||
[Read the PNG file here]
|
||||
|
||||
## Environment
|
||||
- Docker services: [list running containers]
|
||||
- API keys: OpenRouter={present/missing}, E2B={present/missing}
|
||||
**What it shows:** {1-2 sentence explanation of what this screenshot proves}
|
||||
|
||||
## Test Results
|
||||
|
||||
### Scenario 1: {name}
|
||||
**Steps:**
|
||||
1. ...
|
||||
2. ...
|
||||
**Expected:** ...
|
||||
**Actual:** ...
|
||||
**Result:** PASS / FAIL
|
||||
**Screenshot:** {filename}.png
|
||||
**Logs:** (if relevant)
|
||||
|
||||
### Scenario 2: {name}
|
||||
...
|
||||
|
||||
## Summary
|
||||
- Total: X scenarios
|
||||
- Passed: Y
|
||||
- Failed: Z
|
||||
- Bugs found: [list]
|
||||
---
|
||||
```
|
||||
|
||||
Take screenshots at each significant step:
|
||||
After showing all screenshots, output a **detailed** summary table:
|
||||
|
||||
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|
||||
|---|----------|--------|-------------|-------------------|
|
||||
| 1 | {name} | PASS/FAIL | Before: X, After: Y | 01-before.png, 02-after.png |
|
||||
| 2 | ... | ... | ... | ... |
|
||||
|
||||
**IMPORTANT:** As you show each screenshot and record test results, persist them in shell variables for Step 7:
|
||||
|
||||
```bash
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{description}.png
|
||||
# Build these variables during Step 6 — they are required by Step 7's script
|
||||
# NOTE: declare -A requires Bash 4.0+. This is standard on modern systems (macOS ships zsh
|
||||
# but Homebrew bash is 5.x; Linux typically has bash 5.x). If running on Bash <4, use a
|
||||
# plain variable with a lookup function instead.
|
||||
declare -A SCREENSHOT_EXPLANATIONS=(
|
||||
["01-login-page.png"]="Shows the login page loaded successfully with SSO options visible."
|
||||
["02-builder-with-block.png"]="The builder canvas displays the newly added block connected to the trigger."
|
||||
# ... one entry per screenshot, using the same explanations you showed the user above
|
||||
)
|
||||
|
||||
TEST_RESULTS_TABLE="| 1 | Login flow | PASS | N/A | 01-login-before.png, 02-login-after.png |
|
||||
| 2 | Credits purchase | PASS | Before: 100, After: 95 | 03-credits-before.png, 04-credits-after.png |
|
||||
| 3 | Insufficient credits (negative) | PASS | Credits: 0, rejected | 05-insufficient-credits-error.png |"
|
||||
# ... one row per test scenario with actual results
|
||||
```
|
||||
|
||||
## Step 6: Report results
|
||||
## Step 7: Post test report as PR comment with screenshots
|
||||
|
||||
After all tests complete, output a summary to the user:
|
||||
Upload screenshots to the PR using the GitHub Git API (no local git operations — safe for worktrees), then post a comment with inline images and per-screenshot explanations.
|
||||
|
||||
1. Table of all scenarios with PASS/FAIL
|
||||
2. Screenshots of failures (read the PNG files to show them)
|
||||
3. Any bugs found with details
|
||||
4. Recommendations
|
||||
|
||||
### Post test results as PR comment with screenshots
|
||||
|
||||
Upload screenshots to the PR using the GitHub Git API (no local git operations — safe for worktrees).
|
||||
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
|
||||
|
||||
```bash
|
||||
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
|
||||
@@ -406,93 +553,166 @@ REPO="Significant-Gravitas/AutoGPT"
|
||||
SCREENSHOTS_BRANCH="test-screenshots/pr-${PR_NUMBER}"
|
||||
SCREENSHOTS_DIR="test-screenshots/PR-${PR_NUMBER}"
|
||||
|
||||
# Step 1: Create blobs for each screenshot
|
||||
declare -a TREE_ENTRIES
|
||||
for img in $RESULTS_DIR/*.png; do
|
||||
BASENAME=$(basename "$img")
|
||||
B64=$(base64 < "$img")
|
||||
BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha')
|
||||
TREE_ENTRIES+=("-f" "tree[][path]=${SCREENSHOTS_DIR}/${BASENAME}" "-f" "tree[][mode]=100644" "-f" "tree[][type]=blob" "-f" "tree[][sha]=${BLOB_SHA}")
|
||||
done
|
||||
|
||||
# Step 2: Create a tree with all screenshot blobs
|
||||
# Build the tree JSON manually since gh api doesn't handle arrays well
|
||||
# Step 1: Create blobs for each screenshot and build tree JSON
|
||||
# Retry each blob upload up to 3 times. If still failing, list them at end of report.
|
||||
shopt -s nullglob
|
||||
SCREENSHOT_FILES=("$RESULTS_DIR"/*.png)
|
||||
if [ ${#SCREENSHOT_FILES[@]} -eq 0 ]; then
|
||||
echo "ERROR: No screenshots found in $RESULTS_DIR. Test run is incomplete."
|
||||
exit 1
|
||||
fi
|
||||
TREE_JSON='['
|
||||
FIRST=true
|
||||
for img in $RESULTS_DIR/*.png; do
|
||||
FAILED_UPLOADS=()
|
||||
for img in "${SCREENSHOT_FILES[@]}"; do
|
||||
BASENAME=$(basename "$img")
|
||||
B64=$(base64 < "$img")
|
||||
BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha')
|
||||
BLOB_SHA=""
|
||||
for attempt in 1 2 3; do
|
||||
BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha' 2>/dev/null || true)
|
||||
[ -n "$BLOB_SHA" ] && break
|
||||
sleep 1
|
||||
done
|
||||
if [ -z "$BLOB_SHA" ]; then
|
||||
FAILED_UPLOADS+=("$img")
|
||||
continue
|
||||
fi
|
||||
if [ "$FIRST" = true ]; then FIRST=false; else TREE_JSON+=','; fi
|
||||
TREE_JSON+="{\"path\":\"${SCREENSHOTS_DIR}/${BASENAME}\",\"mode\":\"100644\",\"type\":\"blob\",\"sha\":\"${BLOB_SHA}\"}"
|
||||
done
|
||||
TREE_JSON+=']'
|
||||
|
||||
TREE_SHA=$(echo "$TREE_JSON" | gh api "repos/${REPO}/git/trees" --input - -f base_tree="" --jq '.sha' 2>/dev/null \
|
||||
|| echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
|
||||
|
||||
# Step 3: Create a commit pointing to that tree
|
||||
# Step 2: Create tree, commit, and branch ref
|
||||
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
--jq '.sha')
|
||||
|
||||
# Step 4: Create or update the ref (branch) — no local checkout needed
|
||||
gh api "repos/${REPO}/git/refs" \
|
||||
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-f sha="$COMMIT_SHA" 2>/dev/null \
|
||||
|| gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-X PATCH -f sha="$COMMIT_SHA" -f force=true
|
||||
```
|
||||
|
||||
# Step 5: Build image markdown and post the comment
|
||||
Then post the comment with **inline images AND explanations for each screenshot**:
|
||||
|
||||
```bash
|
||||
REPO_URL="https://raw.githubusercontent.com/${REPO}/${SCREENSHOTS_BRANCH}"
|
||||
|
||||
# Build image markdown using uploaded image URLs; skip FAILED_UPLOADS (listed separately)
|
||||
|
||||
IMAGE_MARKDOWN=""
|
||||
for img in $RESULTS_DIR/*.png; do
|
||||
for img in "${SCREENSHOT_FILES[@]}"; do
|
||||
BASENAME=$(basename "$img")
|
||||
IMAGE_MARKDOWN="$IMAGE_MARKDOWN
|
||||
"
|
||||
TITLE=$(echo "${BASENAME%.png}" | sed 's/^[0-9]*-//' | sed 's/-/ /g' | awk '{for(i=1;i<=NF;i++) $i=toupper(substr($i,1,1)) tolower(substr($i,2))}1')
|
||||
# Skip images that failed to upload — they will be listed at the end
|
||||
IS_FAILED=false
|
||||
for failed in "${FAILED_UPLOADS[@]}"; do
|
||||
[ "$(basename "$failed")" = "$BASENAME" ] && IS_FAILED=true && break
|
||||
done
|
||||
if [ "$IS_FAILED" = true ]; then
|
||||
continue
|
||||
fi
|
||||
EXPLANATION="${SCREENSHOT_EXPLANATIONS[$BASENAME]}"
|
||||
if [ -z "$EXPLANATION" ]; then
|
||||
echo "ERROR: Missing screenshot explanation for $BASENAME. Add it to SCREENSHOT_EXPLANATIONS in Step 6."
|
||||
exit 1
|
||||
fi
|
||||
IMAGE_MARKDOWN="${IMAGE_MARKDOWN}
|
||||
### ${TITLE}
|
||||

|
||||
${EXPLANATION}
|
||||
"
|
||||
done
|
||||
|
||||
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -f body="$(cat <<EOF
|
||||
## 🧪 E2E Test Report
|
||||
# Write comment body to file to avoid shell interpretation issues with special characters
|
||||
COMMENT_FILE=$(mktemp)
|
||||
# If any uploads failed, append a section listing them with instructions
|
||||
FAILED_SECTION=""
|
||||
if [ ${#FAILED_UPLOADS[@]} -gt 0 ]; then
|
||||
FAILED_SECTION="
|
||||
## ⚠️ Failed Screenshot Uploads
|
||||
The following screenshots could not be uploaded via the GitHub API after 3 retries.
|
||||
**To add them:** drag-and-drop or paste these files into a PR comment manually:
|
||||
"
|
||||
for failed in "${FAILED_UPLOADS[@]}"; do
|
||||
FAILED_SECTION="${FAILED_SECTION}
|
||||
- \`$(basename "$failed")\` (local path: \`$failed\`)"
|
||||
done
|
||||
FAILED_SECTION="${FAILED_SECTION}
|
||||
|
||||
$(cat $RESULTS_DIR/test-report.md)
|
||||
**Run status:** INCOMPLETE until the files above are manually attached and visible inline in the PR."
|
||||
fi
|
||||
|
||||
cat > "$COMMENT_FILE" <<INNEREOF
|
||||
## E2E Test Report
|
||||
|
||||
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|
||||
|---|----------|--------|-------------|-------------------|
|
||||
${TEST_RESULTS_TABLE}
|
||||
|
||||
### Screenshots
|
||||
${IMAGE_MARKDOWN}
|
||||
EOF
|
||||
)"
|
||||
${FAILED_SECTION}
|
||||
INNEREOF
|
||||
|
||||
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
|
||||
rm -f "$COMMENT_FILE"
|
||||
```
|
||||
|
||||
**The PR comment MUST include:**
|
||||
1. A summary table of all scenarios with PASS/FAIL and before/after API evidence
|
||||
2. Every successfully uploaded screenshot rendered inline; any failed uploads listed with manual attachment instructions
|
||||
3. A 1-2 sentence explanation below each screenshot describing what it proves
|
||||
|
||||
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
|
||||
|
||||
## Fix mode (--fix flag)
|
||||
|
||||
When `--fix` is present, after finding a bug:
|
||||
When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.
|
||||
|
||||
1. Identify the root cause in the code
|
||||
2. Fix it in the worktree
|
||||
3. Rebuild the affected service: `cd $PLATFORM_DIR && docker compose up --build -d {service_name}`
|
||||
4. Re-test the scenario
|
||||
5. If fix works, commit and push:
|
||||
### Fix protocol for EVERY issue found (including UX issues):
|
||||
|
||||
1. **Identify** the root cause in the code — read the relevant source files
|
||||
2. **Write a failing test first** (TDD): For backend bugs, write a test marked with `pytest.mark.xfail(reason="...")`. For frontend/Playwright bugs, write a test with `.fixme` annotation. Run it to confirm it fails as expected.
|
||||
3. **Screenshot** the broken state: `agent-browser screenshot $RESULTS_DIR/{NN}-broken-{description}.png`
|
||||
4. **Fix** the code in the worktree
|
||||
5. **Rebuild** ONLY the affected service (not the whole stack):
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose up --build -d {service_name}
|
||||
# e.g., docker compose up --build -d rest_server
|
||||
# e.g., docker compose up --build -d frontend
|
||||
```
|
||||
6. **Wait** for the service to be ready (poll health endpoint)
|
||||
7. **Re-test** the same scenario
|
||||
8. **Screenshot** the fixed state: `agent-browser screenshot $RESULTS_DIR/{NN}-fixed-{description}.png`
|
||||
9. **Remove the xfail/fixme marker** from the test written in step 2, and verify it passes
|
||||
10. **Verify** the fix did not break other scenarios (run a quick smoke test)
|
||||
11. **Commit and push** immediately:
|
||||
```bash
|
||||
cd $WORKTREE_PATH
|
||||
git add -A
|
||||
git commit -m "fix: {description of fix}"
|
||||
git push
|
||||
```
|
||||
6. Continue testing remaining scenarios
|
||||
7. After all fixes, run the full test suite again to ensure no regressions
|
||||
12. **Continue** to the next test scenario
|
||||
|
||||
### Fix loop (like pr-address)
|
||||
|
||||
```text
|
||||
test scenario → find bug → fix code → rebuild service → re-test
|
||||
→ repeat until all scenarios pass
|
||||
→ commit + push all fixes
|
||||
→ run full re-test to verify
|
||||
test scenario → find issue (bug OR UX problem) → screenshot broken state
|
||||
→ fix code → rebuild affected service only → re-test → screenshot fixed state
|
||||
→ verify no regressions → commit + push
|
||||
→ repeat for next scenario
|
||||
→ after ALL scenarios pass, run full re-test to verify everything together
|
||||
```
|
||||
|
||||
**Key differences from non-fix mode:**
|
||||
- UX issues count as bugs — fix them (bad alignment, confusing labels, missing loading states)
|
||||
- Every fix MUST have a before/after screenshot pair proving it works
|
||||
- Commit after EACH fix, not in a batch at the end
|
||||
- The final re-test must produce a clean set of all-passing screenshots
|
||||
|
||||
## Known issues and workarounds
|
||||
|
||||
### Problem: "Database error finding user" on signup
|
||||
|
||||
195
.claude/skills/setup-repo/SKILL.md
Normal file
195
.claude/skills/setup-repo/SKILL.md
Normal file
@@ -0,0 +1,195 @@
|
||||
---
|
||||
name: setup-repo
|
||||
description: Initialize a worktree-based repo layout for parallel development. Creates a main worktree, a reviews worktree for PR reviews, and N numbered work branches. Handles .env creation, dependency installation, and branchlet config. TRIGGER when user asks to set up the repo from scratch, initialize worktrees, bootstrap their dev environment, "setup repo", "setup worktrees", "initialize dev environment", "set up branches", or when a freshly cloned repo has no sibling worktrees.
|
||||
user-invocable: true
|
||||
args: "No arguments — interactive setup via prompts."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Repository Setup
|
||||
|
||||
This skill sets up a worktree-based development layout from a freshly cloned repo. It creates:
|
||||
- A **main** worktree (the primary checkout)
|
||||
- A **reviews** worktree (for PR reviews)
|
||||
- **N work branches** (branch1..branchN) for parallel development
|
||||
|
||||
## Step 1: Identify the repo
|
||||
|
||||
Determine the repo root and parent directory:
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
REPO_NAME=$(basename "$ROOT")
|
||||
PARENT=$(dirname "$ROOT")
|
||||
```
|
||||
|
||||
Detect if the repo is already inside a worktree layout by counting sibling worktrees (not just checking the directory name, which could be anything):
|
||||
|
||||
```bash
|
||||
# Count worktrees that are siblings (live under $PARENT but aren't $ROOT itself)
|
||||
SIBLING_COUNT=$(git worktree list --porcelain 2>/dev/null | grep "^worktree " | grep -c "$PARENT/" || true)
|
||||
if [ "$SIBLING_COUNT" -gt 1 ]; then
|
||||
echo "INFO: Existing worktree layout detected at $PARENT ($SIBLING_COUNT worktrees)"
|
||||
# Use $ROOT as-is; skip renaming/restructuring
|
||||
else
|
||||
echo "INFO: Fresh clone detected, proceeding with setup"
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 2: Ask the user questions
|
||||
|
||||
Use AskUserQuestion to gather setup preferences:
|
||||
|
||||
1. **How many parallel work branches do you need?** (Options: 4, 8, 16, or custom)
|
||||
- These become `branch1` through `branchN`
|
||||
2. **Which branch should be the base?** (Options: origin/master, origin/dev, or custom)
|
||||
- All work branches and reviews will start from this
|
||||
|
||||
## Step 3: Fetch and set up branches
|
||||
|
||||
```bash
|
||||
cd "$ROOT"
|
||||
git fetch origin
|
||||
|
||||
# Create the reviews branch from base (skip if already exists)
|
||||
if git show-ref --verify --quiet refs/heads/reviews; then
|
||||
echo "INFO: Branch 'reviews' already exists, skipping"
|
||||
else
|
||||
git branch reviews <base-branch>
|
||||
fi
|
||||
|
||||
# Create numbered work branches from base (skip if already exists)
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if git show-ref --verify --quiet "refs/heads/branch$i"; then
|
||||
echo "INFO: Branch 'branch$i' already exists, skipping"
|
||||
else
|
||||
git branch "branch$i" <base-branch>
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 4: Create worktrees
|
||||
|
||||
Create worktrees as siblings to the main checkout:
|
||||
|
||||
```bash
|
||||
if [ -d "$PARENT/reviews" ]; then
|
||||
echo "INFO: Worktree '$PARENT/reviews' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/reviews" reviews
|
||||
fi
|
||||
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if [ -d "$PARENT/branch$i" ]; then
|
||||
echo "INFO: Worktree '$PARENT/branch$i' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/branch$i" "branch$i"
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 5: Set up environment files
|
||||
|
||||
**Do NOT assume .env files exist.** For each worktree (including main if needed):
|
||||
|
||||
1. Check if `.env` exists in the source worktree for each path
|
||||
2. If `.env` exists, copy it
|
||||
3. If only `.env.default` or `.env.example` exists, copy that as `.env`
|
||||
4. If neither exists, warn the user and list which env files are missing
|
||||
|
||||
Env file locations to check (same as the `/worktree` skill — keep these in sync):
|
||||
- `autogpt_platform/.env`
|
||||
- `autogpt_platform/backend/.env`
|
||||
- `autogpt_platform/frontend/.env`
|
||||
|
||||
> **Note:** This env copying logic intentionally mirrors the `/worktree` skill's approach. If you update the path list or fallback logic here, update `/worktree` as well.
|
||||
|
||||
```bash
|
||||
SOURCE="$ROOT"
|
||||
WORKTREES="reviews"
|
||||
for i in $(seq 1 "$COUNT"); do WORKTREES="$WORKTREES branch$i"; done
|
||||
|
||||
FOUND_ANY_ENV=0
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
for envpath in autogpt_platform autogpt_platform/backend autogpt_platform/frontend; do
|
||||
if [ -f "$SOURCE/$envpath/.env" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env" "$TARGET/$envpath/.env"
|
||||
elif [ -f "$SOURCE/$envpath/.env.default" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.default — you may need to edit it"
|
||||
elif [ -f "$SOURCE/$envpath/.env.example" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.example" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.example — you may need to edit it"
|
||||
else
|
||||
echo "WARNING: No .env, .env.default, or .env.example found at $SOURCE/$envpath/"
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
if [ "$FOUND_ANY_ENV" -eq 0 ]; then
|
||||
echo "WARNING: No environment files or templates were found in the source worktree."
|
||||
# Use AskUserQuestion to confirm: "Continue setup without env files?"
|
||||
# If the user declines, stop here and let them set up .env files first.
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 6: Copy branchlet config
|
||||
|
||||
Copy `.branchlet.json` from main to each worktree so branchlet can manage sub-worktrees:
|
||||
|
||||
```bash
|
||||
if [ -f "$ROOT/.branchlet.json" ]; then
|
||||
for wt in $WORKTREES; do
|
||||
cp "$ROOT/.branchlet.json" "$PARENT/$wt/.branchlet.json"
|
||||
done
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 7: Install dependencies
|
||||
|
||||
Install deps in all worktrees. Run these sequentially per worktree:
|
||||
|
||||
```bash
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
echo "=== Installing deps for $wt ==="
|
||||
(cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install) &&
|
||||
(cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate) &&
|
||||
(cd "$TARGET/autogpt_platform/frontend" && pnpm install) &&
|
||||
echo "=== Done: $wt ===" ||
|
||||
echo "=== FAILED: $wt ==="
|
||||
done
|
||||
```
|
||||
|
||||
This is slow. Run in background if possible and notify when complete.
|
||||
|
||||
## Step 8: Verify and report
|
||||
|
||||
After setup, verify and report to the user:
|
||||
|
||||
```bash
|
||||
git worktree list
|
||||
```
|
||||
|
||||
Summarize:
|
||||
- Number of worktrees created
|
||||
- Which env files were copied vs created from defaults vs missing
|
||||
- Any warnings or errors encountered
|
||||
|
||||
## Final directory layout
|
||||
|
||||
```
|
||||
parent/
|
||||
main/ # Primary checkout (already exists)
|
||||
reviews/ # PR review worktree
|
||||
branch1/ # Work branch 1
|
||||
branch2/ # Work branch 2
|
||||
...
|
||||
branchN/ # Work branch N
|
||||
```
|
||||
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,8 +1,12 @@
|
||||
<!-- Clearly explain the need for these changes: -->
|
||||
### Why / What / How
|
||||
|
||||
<!-- Why: Why does this PR exist? What problem does it solve, or what's broken/missing without it? -->
|
||||
<!-- What: What does this PR change? Summarize the changes at a high level. -->
|
||||
<!-- How: How does it work? Describe the approach, key implementation details, or architecture decisions. -->
|
||||
|
||||
### Changes 🏗️
|
||||
|
||||
<!-- Concisely describe all of the changes made in this pull request: -->
|
||||
<!-- List the key changes. Keep it higher level than the diff but specific enough to highlight what's new/modified. -->
|
||||
|
||||
### Checklist 📋
|
||||
|
||||
|
||||
@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
|
||||
|
||||
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
|
||||
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
|
||||
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
|
||||
|
||||
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
|
||||
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
|
||||
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.
|
||||
|
||||
|
||||
@@ -53,8 +53,10 @@ AutoGPT Platform is a monorepo containing:
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
|
||||
```bash
|
||||
|
||||
54
autogpt_platform/autogpt_libs/poetry.lock
generated
54
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
@@ -67,7 +67,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
|
||||
optional = false
|
||||
python-versions = "<3.11,>=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
||||
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
||||
@@ -541,7 +541,7 @@ description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
|
||||
{file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
|
||||
@@ -2181,14 +2181,14 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "7.0.0"
|
||||
version = "7.1.0"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"},
|
||||
{file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"},
|
||||
{file = "pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678"},
|
||||
{file = "pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2342,30 +2342,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.15.0"
|
||||
version = "0.15.7"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.15.0-py3-none-linux_armv6l.whl", hash = "sha256:aac4ebaa612a82b23d45964586f24ae9bc23ca101919f5590bdb368d74ad5455"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:dcd4be7cc75cfbbca24a98d04d0b9b36a270d0833241f776b788d59f4142b14d"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d747e3319b2bce179c7c1eaad3d884dc0a199b5f4d5187620530adf9105268ce"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:650bd9c56ae03102c51a5e4b554d74d825ff3abe4db22b90fd32d816c2e90621"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a6664b7eac559e3048223a2da77769c2f92b43a6dfd4720cef42654299a599c9"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f811f97b0f092b35320d1556f3353bf238763420ade5d9e62ebd2b73f2ff179"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:761ec0a66680fab6454236635a39abaf14198818c8cdf691e036f4bc0f406b2d"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:940f11c2604d317e797b289f4f9f3fa5555ffe4fb574b55ed006c3d9b6f0eb78"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbca3d40558789126da91d7ef9a7c87772ee107033db7191edefa34e2c7f1b4"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9a121a96db1d75fa3eb39c4539e607f628920dd72ff1f7c5ee4f1b768ac62d6e"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5298d518e493061f2eabd4abd067c7e4fb89e2f63291c94332e35631c07c3662"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:afb6e603d6375ff0d6b0cee563fa21ab570fd15e65c852cb24922cef25050cf1"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:77e515f6b15f828b94dc17d2b4ace334c9ddb7d9468c54b2f9ed2b9c1593ef16"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6f6e80850a01eb13b3e42ee0ebdf6e4497151b48c35051aab51c101266d187a3"},
|
||||
{file = "ruff-0.15.0-py3-none-win32.whl", hash = "sha256:238a717ef803e501b6d51e0bdd0d2c6e8513fe9eec14002445134d3907cd46c3"},
|
||||
{file = "ruff-0.15.0-py3-none-win_amd64.whl", hash = "sha256:dd5e4d3301dc01de614da3cdffc33d4b1b96fb89e45721f1598e5532ccf78b18"},
|
||||
{file = "ruff-0.15.0-py3-none-win_arm64.whl", hash = "sha256:c480d632cc0ca3f0727acac8b7d053542d9e114a462a145d0b00e7cd658c515a"},
|
||||
{file = "ruff-0.15.0.tar.gz", hash = "sha256:6bdea47cdbea30d40f8f8d7d69c0854ba7c15420ec75a26f463290949d7f7e9a"},
|
||||
{file = "ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580"},
|
||||
{file = "ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de"},
|
||||
{file = "ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1"},
|
||||
{file = "ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2"},
|
||||
{file = "ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2564,7 +2564,7 @@ description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
|
||||
@@ -2912,4 +2912,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"
|
||||
content-hash = "e0936a065565550afed18f6298b7e04e814b44100def7049f1a0d68662624a39"
|
||||
|
||||
@@ -26,8 +26,8 @@ pyright = "^1.1.408"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-cov = "^7.0.0"
|
||||
ruff = "^0.15.0"
|
||||
pytest-cov = "^7.1.0"
|
||||
ruff = "^0.15.7"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -178,6 +178,7 @@ SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Business & Marketing Tools
|
||||
AGENTMAIL_API_KEY=
|
||||
APOLLO_API_KEY=
|
||||
ENRICHLAYER_API_KEY=
|
||||
AYRSHARE_API_KEY=
|
||||
|
||||
@@ -61,6 +61,7 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
|
||||
@@ -121,36 +121,20 @@ RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium.
|
||||
# On amd64: install runtime libs + run `agent-browser install` to download
|
||||
# Chrome for Testing (pinned version, tested with Playwright).
|
||||
# On arm64: install system chromium package — Chrome for Testing has no ARM64
|
||||
# binary. AGENT_BROWSER_EXECUTABLE_PATH is set at runtime by the entrypoint
|
||||
# script (below) to redirect agent-browser to the system binary.
|
||||
ARG TARGETARCH
|
||||
# Install agent-browser (Copilot browser tool) using the system chromium package.
|
||||
# Chrome for Testing (the binary agent-browser downloads via `agent-browser install`)
|
||||
# has no ARM64 builds, so we use the distro-packaged chromium instead — verified to
|
||||
# work with agent-browser via Docker tests on arm64; amd64 is validated in CI.
|
||||
# Note: system chromium tracks the Debian package schedule rather than a pinned
|
||||
# Chrome for Testing release. If agent-browser requires a specific Chrome version,
|
||||
# verify compatibility against the chromium package version in the base image.
|
||||
RUN apt-get update \
|
||||
&& if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
apt-get install -y --no-install-recommends chromium fonts-liberation; \
|
||||
else \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1; \
|
||||
fi \
|
||||
&& apt-get install -y --no-install-recommends chromium fonts-liberation \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& ([ "$TARGETARCH" = "arm64" ] || agent-browser install) \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
# On arm64 the system chromium is at /usr/bin/chromium; set
|
||||
# AGENT_BROWSER_EXECUTABLE_PATH so agent-browser's daemon uses it instead of
|
||||
# Chrome for Testing (which has no ARM64 binary). On amd64 the variable is left
|
||||
# unset so agent-browser uses the Chrome for Testing binary it downloaded above.
|
||||
RUN printf '#!/bin/sh\n[ -x /usr/bin/chromium ] && export AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium\nexec "$@"\n' \
|
||||
> /usr/local/bin/entrypoint.sh \
|
||||
&& chmod +x /usr/local/bin/entrypoint.sh
|
||||
ENV AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
@@ -173,5 +157,4 @@ RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
|
||||
ENV PORT=8000
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["rest"]
|
||||
|
||||
@@ -18,14 +18,22 @@ from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.integrations.models import get_all_provider_names
|
||||
from backend.api.features.integrations.router import (
|
||||
CredentialsMetaResponse,
|
||||
to_meta_response,
|
||||
)
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
@@ -91,18 +99,6 @@ class OAuthCompleteResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class CredentialSummary(BaseModel):
|
||||
"""Summary of a credential without sensitive data."""
|
||||
|
||||
id: str
|
||||
provider: str
|
||||
type: CredentialsType
|
||||
title: Optional[str] = None
|
||||
scopes: Optional[list[str]] = None
|
||||
username: Optional[str] = None
|
||||
host: Optional[str] = None
|
||||
|
||||
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Information about an integration provider."""
|
||||
|
||||
@@ -473,12 +469,12 @@ async def complete_oauth(
|
||||
)
|
||||
|
||||
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialsMetaResponse])
|
||||
async def list_credentials(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
"""
|
||||
List all credentials for the authenticated user.
|
||||
|
||||
@@ -486,28 +482,19 @@ async def list_credentials(
|
||||
"""
|
||||
credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@integrations_router.get(
|
||||
"/{provider}/credentials", response_model=list[CredentialSummary]
|
||||
"/{provider}/credentials", response_model=list[CredentialsMetaResponse]
|
||||
)
|
||||
async def list_credentials_by_provider(
|
||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
"""
|
||||
List credentials for a specific provider.
|
||||
"""
|
||||
@@ -515,16 +502,7 @@ async def list_credentials_by_provider(
|
||||
auth.user_id, provider
|
||||
)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -597,11 +575,11 @@ async def create_credential(
|
||||
# Store credentials
|
||||
try:
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store credentials: {e}")
|
||||
except Exception:
|
||||
logger.exception("Failed to store credentials")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
detail="Failed to store credentials",
|
||||
)
|
||||
|
||||
logger.info(f"Created {request.type} credentials for provider {provider}")
|
||||
@@ -639,15 +617,23 @@ async def delete_credential(
|
||||
use the main API's delete endpoint which handles webhook cleanup and
|
||||
token revocation.
|
||||
"""
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if creds.provider != provider:
|
||||
if not provider_matches(creds.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
|
||||
await creds_manager.delete(auth.user_id, cred_id)
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
"""Admin endpoints for checking and resetting user CoPilot rate limit usage."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.rate_limit import (
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
reset_user_usage,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_email_by_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["copilot", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
class UserRateLimitResponse(BaseModel):
|
||||
user_id: str
|
||||
user_email: Optional[str] = None
|
||||
daily_token_limit: int
|
||||
weekly_token_limit: int
|
||||
daily_tokens_used: int
|
||||
weekly_tokens_used: int
|
||||
|
||||
|
||||
async def _resolve_user_id(
|
||||
user_id: Optional[str], email: Optional[str]
|
||||
) -> tuple[str, Optional[str]]:
|
||||
"""Resolve a user_id and email from the provided parameters.
|
||||
|
||||
Returns (user_id, email). Accepts either user_id or email; at least one
|
||||
must be provided. When both are provided, ``email`` takes precedence.
|
||||
"""
|
||||
if email:
|
||||
user = await get_user_by_email(email)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No user found with the provided email."
|
||||
)
|
||||
return user.id, email
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either user_id or email query parameter is required.",
|
||||
)
|
||||
|
||||
# We have a user_id; try to look up their email for display purposes.
|
||||
# This is non-critical -- a failure should not block the response.
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
|
||||
resolved_email = None
|
||||
return user_id, resolved_email
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit",
|
||||
response_model=UserRateLimitResponse,
|
||||
summary="Get User Rate Limit",
|
||||
)
|
||||
async def get_user_rate_limit(
|
||||
user_id: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserRateLimitResponse:
|
||||
"""Get a user's current usage and effective rate limits. Admin-only.
|
||||
|
||||
Accepts either ``user_id`` or ``email`` as a query parameter.
|
||||
When ``email`` is provided the user is looked up by email first.
|
||||
"""
|
||||
resolved_id, resolved_email = await _resolve_user_id(user_id, email)
|
||||
|
||||
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
resolved_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=resolved_id,
|
||||
user_email=resolved_email,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rate_limit/reset",
|
||||
response_model=UserRateLimitResponse,
|
||||
summary="Reset User Rate Limit Usage",
|
||||
)
|
||||
async def reset_user_rate_limit(
|
||||
user_id: str = Body(embed=True),
|
||||
reset_weekly: bool = Body(False, embed=True),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserRateLimitResponse:
|
||||
"""Reset a user's daily usage counter (and optionally weekly). Admin-only."""
|
||||
logger.info(
|
||||
"Admin %s resetting rate limit for user %s (reset_weekly=%s)",
|
||||
admin_user_id,
|
||||
user_id,
|
||||
reset_weekly,
|
||||
)
|
||||
|
||||
try:
|
||||
await reset_user_usage(user_id, reset_weekly=reset_weekly)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to reset user usage")
|
||||
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
|
||||
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
|
||||
resolved_email = None
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=user_id,
|
||||
user_email=resolved_email,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
)
|
||||
@@ -0,0 +1,263 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
from .rate_limit_admin_routes import router as rate_limit_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(rate_limit_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
_MOCK_MODULE = "backend.api.features.admin.rate_limit_admin_routes"
|
||||
|
||||
_TARGET_EMAIL = "target@example.com"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _mock_usage_status(
|
||||
daily_used: int = 500_000, weekly_used: int = 3_000_000
|
||||
) -> CoPilotUsageStatus:
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
now = datetime.now(UTC)
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used, limit=2_500_000, resets_at=now + timedelta(hours=6)
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used, limit=12_500_000, resets_at=now + timedelta(days=3)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _patch_rate_limit_deps(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
daily_used: int = 500_000,
|
||||
weekly_used: int = 3_000_000,
|
||||
):
|
||||
"""Patch the common rate-limit + user-lookup dependencies."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(daily_used=daily_used, weekly_used=weekly_used),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
|
||||
|
||||
def test_get_rate_limit(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting rate limit and usage for a user."""
|
||||
_patch_rate_limit_deps(mocker, target_user_id)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
assert data["weekly_token_limit"] == 12_500_000
|
||||
assert data["daily_tokens_used"] == 500_000
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"get_rate_limit",
|
||||
)
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test looking up rate limits via email instead of user_id."""
|
||||
_patch_rate_limit_deps(mocker, target_user_id)
|
||||
|
||||
mock_user = SimpleNamespace(id=target_user_id, email=_TARGET_EMAIL)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"email": _TARGET_EMAIL})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Test that looking up a non-existent email returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"email": "nobody@example.com"})
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_get_rate_limit_no_params() -> None:
|
||||
"""Test that omitting both user_id and email returns 400."""
|
||||
response = client.get("/admin/rate_limit")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_reset_user_usage_daily_only(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test resetting only daily usage (default behaviour)."""
|
||||
mock_reset = mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=3_000_000)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
# Weekly is untouched
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"reset_user_usage_daily_only",
|
||||
)
|
||||
|
||||
|
||||
def test_reset_user_usage_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test resetting both daily and weekly usage."""
|
||||
mock_reset = mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=0)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id, "reset_weekly": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
assert data["weekly_tokens_used"] == 0
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"reset_user_usage_daily_and_weekly",
|
||||
)
|
||||
|
||||
|
||||
def test_reset_user_usage_redis_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that Redis failure on reset returns 500."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Redis connection refused"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_get_rate_limit_email_lookup_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that failing to resolve a user email degrades gracefully."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection lost"),
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] is None
|
||||
|
||||
|
||||
def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that rate limit admin endpoints require admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": "test"})
|
||||
assert response.status_code == 403
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": "test"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
@@ -7,6 +7,8 @@ import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
@@ -132,3 +134,40 @@ async def admin_download_agent_file(
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/submissions/{store_listing_version_id}/preview",
|
||||
summary="Admin Preview Submission Listing",
|
||||
)
|
||||
async def admin_preview_submission(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""
|
||||
Preview a marketplace submission as it would appear on the listing page.
|
||||
Bypasses the APPROVED-only StoreAgent view so admins can preview pending
|
||||
submissions before approving.
|
||||
"""
|
||||
return await store_db.get_store_agent_details_as_admin(store_listing_version_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/add-to-library",
|
||||
summary="Admin Add Pending Agent to Library",
|
||||
status_code=201,
|
||||
)
|
||||
async def admin_add_agent_to_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Add a pending marketplace agent to the admin's library for review.
|
||||
Uses admin-level access to bypass marketplace APPROVED-only checks.
|
||||
|
||||
The builder can load the graph because get_graph() checks library
|
||||
membership as a fallback: "you added it, you keep it."
|
||||
"""
|
||||
return await library_db.add_store_agent_to_library_as_admin(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
"""Tests for admin store routes and the bypass logic they depend on.
|
||||
|
||||
Tests are organized by what they protect:
|
||||
- SECRT-2162: get_graph_as_admin bypasses ownership/marketplace checks
|
||||
- SECRT-2167 security: admin endpoints reject non-admin users
|
||||
- SECRT-2167 bypass: preview queries StoreListingVersion (not StoreAgent view),
|
||||
and add-to-library uses get_graph_as_admin (not get_graph)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.graph import get_graph_as_admin
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .store_admin_routes import router as store_admin_router
|
||||
|
||||
# Shared constants
|
||||
ADMIN_USER_ID = "admin-user-id"
|
||||
CREATOR_USER_ID = "other-creator-id"
|
||||
GRAPH_ID = "test-graph-id"
|
||||
GRAPH_VERSION = 3
|
||||
SLV_ID = "test-store-listing-version-id"
|
||||
|
||||
|
||||
def _make_mock_graph(user_id: str = CREATOR_USER_ID) -> MagicMock:
|
||||
graph = MagicMock()
|
||||
graph.userId = user_id
|
||||
graph.id = GRAPH_ID
|
||||
graph.version = GRAPH_VERSION
|
||||
graph.Nodes = []
|
||||
return graph
|
||||
|
||||
|
||||
# ---- SECRT-2162: get_graph_as_admin bypasses ownership checks ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_access_pending_agent_not_owned() -> None:
|
||||
"""get_graph_as_admin must return a graph even when the admin doesn't own
|
||||
it and it's not APPROVED in the marketplace."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=False,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_download_pending_agent_with_subagents() -> None:
|
||||
"""get_graph_as_admin with for_export=True must call get_sub_graphs
|
||||
and pass sub_graphs to GraphModel.from_db."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_sub_graph = MagicMock(name="SubGraph")
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.get_sub_graphs",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_sub_graph],
|
||||
) as mock_get_sub,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=True,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_get_sub.assert_awaited_once_with(mock_graph)
|
||||
mock_from_db.assert_called_once_with(
|
||||
graph=mock_graph,
|
||||
sub_graphs=[mock_sub_graph],
|
||||
for_export=True,
|
||||
)
|
||||
|
||||
|
||||
# ---- SECRT-2167 security: admin endpoints reject non-admin users ---- #
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(store_admin_router)
|
||||
|
||||
|
||||
@app.exception_handler(NotFoundError)
|
||||
async def _not_found_handler(
|
||||
request: fastapi.Request, exc: NotFoundError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all route tests in this module."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_preview_requires_admin(mock_jwt_user) -> None:
|
||||
"""Non-admin users must get 403 on the preview endpoint."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_add_to_library_requires_admin(mock_jwt_user) -> None:
|
||||
"""Non-admin users must get 403 on the add-to-library endpoint."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
response = client.post(f"/admin/submissions/{SLV_ID}/add-to-library")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_preview_nonexistent_submission(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Preview of a nonexistent submission returns 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.store_admin_routes.store_db"
|
||||
".get_store_agent_details_as_admin",
|
||||
side_effect=NotFoundError("not found"),
|
||||
)
|
||||
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- SECRT-2167 bypass: verify the right data sources are used ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_queries_store_listing_version_not_store_agent() -> None:
|
||||
"""get_store_agent_details_as_admin must query StoreListingVersion
|
||||
directly (not the APPROVED-only StoreAgent view). This is THE test that
|
||||
prevents the bypass from being accidentally reverted."""
|
||||
from backend.api.features.store.db import get_store_agent_details_as_admin
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.id = SLV_ID
|
||||
mock_slv.name = "Test Agent"
|
||||
mock_slv.subHeading = "Short desc"
|
||||
mock_slv.description = "Long desc"
|
||||
mock_slv.videoUrl = None
|
||||
mock_slv.agentOutputDemoUrl = None
|
||||
mock_slv.imageUrls = ["https://example.com/img.png"]
|
||||
mock_slv.instructions = None
|
||||
mock_slv.categories = ["productivity"]
|
||||
mock_slv.version = 1
|
||||
mock_slv.agentGraphId = GRAPH_ID
|
||||
mock_slv.agentGraphVersion = GRAPH_VERSION
|
||||
mock_slv.updatedAt = datetime(2026, 3, 24, tzinfo=timezone.utc)
|
||||
mock_slv.recommendedScheduleCron = "0 9 * * *"
|
||||
|
||||
mock_listing = MagicMock()
|
||||
mock_listing.id = "listing-id"
|
||||
mock_listing.slug = "test-agent"
|
||||
mock_listing.activeVersionId = SLV_ID
|
||||
mock_listing.hasApprovedVersion = False
|
||||
mock_listing.CreatorProfile = MagicMock(username="creator", avatarUrl="")
|
||||
mock_slv.StoreListing = mock_listing
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.store.db.prisma.models" ".StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch(
|
||||
"backend.api.features.store.db.prisma.models.StoreAgent.prisma",
|
||||
) as mock_store_agent_prisma,
|
||||
):
|
||||
mock_slv_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await get_store_agent_details_as_admin(SLV_ID)
|
||||
|
||||
# Verify it queried StoreListingVersion (not the APPROVED-only StoreAgent)
|
||||
mock_slv_prisma.return_value.find_unique.assert_awaited_once()
|
||||
await_args = mock_slv_prisma.return_value.find_unique.await_args
|
||||
assert await_args is not None
|
||||
assert await_args.kwargs["where"] == {"id": SLV_ID}
|
||||
|
||||
# Verify the APPROVED-only StoreAgent view was NOT touched
|
||||
mock_store_agent_prisma.assert_not_called()
|
||||
|
||||
# Verify the result has the right data
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.agent_image == ["https://example.com/img.png"]
|
||||
assert result.has_approved_version is False
|
||||
assert result.runs == 0
|
||||
assert result.rating == 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_graph_admin_uses_get_graph_as_admin() -> None:
|
||||
"""resolve_graph_for_library(admin=True) must call get_graph_as_admin,
|
||||
not get_graph. This is THE test that prevents the add-to-library bypass
|
||||
from being accidentally reverted."""
|
||||
from backend.api.features.library._add_to_library import resolve_graph_for_library
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models"
|
||||
".StoreListingVersion.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
".get_graph_as_admin",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_graph_model,
|
||||
) as mock_admin,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db.get_graph",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_regular,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await resolve_graph_for_library(SLV_ID, ADMIN_USER_ID, admin=True)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_admin.assert_awaited_once_with(
|
||||
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id=ADMIN_USER_ID
|
||||
)
|
||||
mock_regular.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_graph_regular_uses_get_graph() -> None:
|
||||
"""resolve_graph_for_library(admin=False) must call get_graph,
|
||||
not get_graph_as_admin. Ensures the non-admin path is preserved."""
|
||||
from backend.api.features.library._add_to_library import resolve_graph_for_library
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models"
|
||||
".StoreListingVersion.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
".get_graph_as_admin",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_admin,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db.get_graph",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_graph_model,
|
||||
) as mock_regular,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await resolve_graph_for_library(SLV_ID, "regular-user-id", admin=False)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_regular.assert_awaited_once_with(
|
||||
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id="regular-user-id"
|
||||
)
|
||||
mock_admin.assert_not_awaited()
|
||||
|
||||
|
||||
# ---- Library membership grants graph access (product decision) ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_member_can_view_pending_agent_in_builder() -> None:
|
||||
"""After adding a pending agent to their library, the user should be
|
||||
able to load the graph in the builder via get_graph()."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
mock_library_agent = MagicMock()
|
||||
mock_library_agent.AgentGraph = mock_graph
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_lib_prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_library_agent
|
||||
)
|
||||
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model, "Library membership should grant graph access"
|
||||
@@ -30,8 +30,14 @@ from backend.copilot.model import (
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
acquire_reset_lock,
|
||||
check_rate_limit,
|
||||
get_daily_reset_count,
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
increment_daily_reset_count,
|
||||
release_reset_lock,
|
||||
reset_daily_usage,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
@@ -59,9 +65,16 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.exceptions import InsufficientBalanceError, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -69,8 +82,6 @@ _UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
@@ -421,11 +432,187 @@ async def get_copilot_usage(
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
Global defaults sourced from LaunchDarkly (falling back to config).
|
||||
"""
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
return await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
|
||||
class RateLimitResetResponse(BaseModel):
|
||||
"""Response from resetting the daily rate limit."""
|
||||
|
||||
success: bool
|
||||
credits_charged: int = Field(description="Credits charged (in cents)")
|
||||
remaining_balance: int = Field(description="Credit balance after charge (in cents)")
|
||||
usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/usage/reset",
|
||||
status_code=200,
|
||||
responses={
|
||||
400: {
|
||||
"description": "Bad Request (feature disabled or daily limit not reached)"
|
||||
},
|
||||
402: {"description": "Payment Required (insufficient credits)"},
|
||||
429: {
|
||||
"description": "Too Many Requests (max daily resets exceeded or reset in progress)"
|
||||
},
|
||||
503: {
|
||||
"description": "Service Unavailable (Redis reset failed; credits refunded or support needed)"
|
||||
},
|
||||
},
|
||||
)
|
||||
async def reset_copilot_usage(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> RateLimitResetResponse:
|
||||
"""Reset the daily CoPilot rate limit by spending credits.
|
||||
|
||||
Allows users who have hit their daily token limit to spend credits
|
||||
to reset their daily usage counter and continue working.
|
||||
Returns 400 if the feature is disabled or the user is not over the limit.
|
||||
Returns 402 if the user has insufficient credits.
|
||||
"""
|
||||
cost = config.rate_limit_reset_cost
|
||||
if cost <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Rate limit reset is not available.",
|
||||
)
|
||||
|
||||
if not settings.config.enable_credit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Rate limit reset is not available (credit system is disabled).",
|
||||
)
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
|
||||
if daily_limit <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No daily limit is configured — nothing to reset.",
|
||||
)
|
||||
|
||||
# Check max daily resets. get_daily_reset_count returns None when Redis
|
||||
# is unavailable; reject the reset in that case to prevent unlimited
|
||||
# free resets when the counter store is down.
|
||||
reset_count = await get_daily_reset_count(user_id)
|
||||
if reset_count is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Unable to verify reset eligibility — please try again later.",
|
||||
)
|
||||
if config.max_daily_resets > 0 and reset_count >= config.max_daily_resets:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"You've used all {config.max_daily_resets} resets for today.",
|
||||
)
|
||||
|
||||
# Acquire a per-user lock to prevent TOCTOU races (concurrent resets).
|
||||
if not await acquire_reset_lock(user_id):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="A reset is already in progress. Please try again.",
|
||||
)
|
||||
|
||||
try:
|
||||
# Verify the user is actually at or over their daily limit.
|
||||
usage_status = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
)
|
||||
if daily_limit > 0 and usage_status.daily.used < daily_limit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="You have not reached your daily limit yet.",
|
||||
)
|
||||
|
||||
# If the weekly limit is also exhausted, resetting the daily counter
|
||||
# won't help — the user would still be blocked by the weekly limit.
|
||||
if weekly_limit > 0 and usage_status.weekly.used >= weekly_limit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Your weekly limit is also reached. Resetting the daily limit won't help.",
|
||||
)
|
||||
|
||||
# Charge credits.
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
try:
|
||||
remaining = await credit_model.spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
reason="CoPilot daily rate limit reset",
|
||||
),
|
||||
)
|
||||
except InsufficientBalanceError as e:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient credits to reset your rate limit.",
|
||||
) from e
|
||||
|
||||
# Reset daily usage in Redis. If this fails, refund the credits
|
||||
# so the user is not charged for a service they did not receive.
|
||||
if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
|
||||
# Compensate: refund the charged credits.
|
||||
refunded = False
|
||||
try:
|
||||
await credit_model.top_up_credits(user_id, cost)
|
||||
refunded = True
|
||||
logger.warning(
|
||||
"Refunded %d credits to user %s after Redis reset failure",
|
||||
cost,
|
||||
user_id[:8],
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
"CRITICAL: Failed to refund %d credits to user %s "
|
||||
"after Redis reset failure — manual intervention required",
|
||||
cost,
|
||||
user_id[:8],
|
||||
exc_info=True,
|
||||
)
|
||||
if refunded:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Rate limit reset failed — please try again later. "
|
||||
"Your credits have not been charged.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Rate limit reset failed and the automatic refund "
|
||||
"also failed. Please contact support for assistance.",
|
||||
)
|
||||
|
||||
# Track the reset count for daily cap enforcement.
|
||||
await increment_daily_reset_count(user_id)
|
||||
finally:
|
||||
await release_reset_lock(user_id)
|
||||
|
||||
# Return updated usage status.
|
||||
updated_usage = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
return RateLimitResetResponse(
|
||||
success=True,
|
||||
credits_charged=cost,
|
||||
remaining_balance=remaining,
|
||||
usage=updated_usage,
|
||||
)
|
||||
|
||||
|
||||
@@ -526,12 +713,16 @@ async def stream_chat_post(
|
||||
|
||||
# Pre-turn rate limit check (token-based).
|
||||
# check_rate_limit short-circuits internally when both limits are 0.
|
||||
# Global defaults sourced from LaunchDarkly, falling back to config.
|
||||
if user_id:
|
||||
try:
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
@@ -894,6 +1085,47 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Suggested Prompts ==========
|
||||
|
||||
|
||||
class SuggestedTheme(BaseModel):
|
||||
"""A themed group of suggested prompts."""
|
||||
|
||||
name: str
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
class SuggestedPromptsResponse(BaseModel):
|
||||
"""Response model for user-specific suggested prompts grouped by theme."""
|
||||
|
||||
themes: list[SuggestedTheme]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggested-prompts",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def get_suggested_prompts(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SuggestedPromptsResponse:
|
||||
"""
|
||||
Get LLM-generated suggested prompts grouped by theme.
|
||||
|
||||
Returns personalized quick-action prompts based on the user's
|
||||
business understanding. Returns empty themes list if no custom
|
||||
prompts are available.
|
||||
"""
|
||||
understanding = await get_business_understanding(user_id)
|
||||
if understanding is None or not understanding.suggested_prompts:
|
||||
return SuggestedPromptsResponse(themes=[])
|
||||
|
||||
themes = [
|
||||
SuggestedTheme(name=name, prompts=prompts)
|
||||
for name, prompts in understanding.suggested_prompts.items()
|
||||
]
|
||||
return SuggestedPromptsResponse(themes=themes)
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -368,6 +368,7 @@ def test_usage_returns_daily_and_weekly(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
|
||||
@@ -380,6 +381,7 @@ def test_usage_uses_config_limits(
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
@@ -388,6 +390,7 @@ def test_usage_uses_config_limits(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
rate_limit_reset_cost=500,
|
||||
)
|
||||
|
||||
|
||||
@@ -400,3 +403,69 @@ def test_usage_rejects_unauthenticated_request() -> None:
|
||||
response = unauthenticated_client.get("/usage")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_themes(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with themed prompts gets them back as themes list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = {
|
||||
"Learn": ["L1", "L2"],
|
||||
"Create": ["C1"],
|
||||
}
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "themes" in data
|
||||
themes_by_name = {t["name"]: t["prompts"] for t in data["themes"]}
|
||||
assert themes_by_name["Learn"] == ["L1", "L2"]
|
||||
assert themes_by_name["Create"] == ["C1"]
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty themes list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"themes": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but empty prompts gets empty themes list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = {}
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"themes": []}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""Override session-scoped fixtures so unit tests run without the server."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server():
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def graph_cleanup():
|
||||
yield
|
||||
@@ -34,16 +34,21 @@ from backend.data.model import (
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
create_mcp_oauth_handler,
|
||||
)
|
||||
from backend.integrations.managed_credentials import ensure_managed_credentials
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -109,6 +114,7 @@ class CredentialsMetaResponse(BaseModel):
|
||||
default=None,
|
||||
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
||||
)
|
||||
is_managed: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -138,6 +144,19 @@ class CredentialsMetaResponse(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
|
||||
return CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
is_managed=cred.is_managed,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||
async def callback(
|
||||
provider: Annotated[
|
||||
@@ -204,34 +223,20 @@ async def callback(
|
||||
f"and provider {provider.value}"
|
||||
)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(CredentialsMetaResponse.get_host(credentials)),
|
||||
)
|
||||
return to_meta_response(credentials)
|
||||
|
||||
|
||||
@router.get("/credentials", summary="List Credentials")
|
||||
async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
# Fire-and-forget: provision missing managed credentials in the background.
|
||||
# The credential appears on the next page load; listing is never blocked.
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -242,19 +247,11 @@ async def list_credentials_by_provider(
|
||||
],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -267,18 +264,21 @@ async def get_credential(
|
||||
],
|
||||
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> Credentials:
|
||||
) -> CredentialsMetaResponse:
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
credential = await creds_manager.get(user_id, cred_id)
|
||||
if not credential:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if credential.provider != provider:
|
||||
if not provider_matches(credential.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
return credential
|
||||
return to_meta_response(credential)
|
||||
|
||||
|
||||
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
|
||||
@@ -288,16 +288,22 @@ async def create_credentials(
|
||||
ProviderName, Path(title="The provider to create credentials for")
|
||||
],
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
) -> CredentialsMetaResponse:
|
||||
if is_sdk_default(credentials.id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Cannot create credentials with a reserved ID",
|
||||
)
|
||||
credentials.provider = provider
|
||||
try:
|
||||
await creds_manager.create(user_id, credentials)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Failed to store credentials")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
detail="Failed to store credentials",
|
||||
)
|
||||
return credentials
|
||||
return to_meta_response(credentials)
|
||||
|
||||
|
||||
class CredentialsDeletionResponse(BaseModel):
|
||||
@@ -332,15 +338,29 @@ async def delete_credentials(
|
||||
bool, Query(title="Whether to proceed if any linked webhooks are still in use")
|
||||
] = False,
|
||||
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if creds.provider != provider:
|
||||
if not provider_matches(creds.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
detail="Credentials not found",
|
||||
)
|
||||
if creds.is_managed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="AutoGPT-managed credentials cannot be deleted",
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,570 @@
|
||||
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.integrations.router import router
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "test-user-id"
|
||||
|
||||
|
||||
def _make_api_key_cred(cred_id: str = "cred-123", provider: str = "openai"):
|
||||
return APIKeyCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My API Key",
|
||||
api_key=SecretStr("sk-secret-key-value"),
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_cred(cred_id: str = "cred-456", provider: str = "github"):
|
||||
return OAuth2Credentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My OAuth",
|
||||
access_token=SecretStr("ghp_secret_token"),
|
||||
refresh_token=SecretStr("ghp_refresh_secret"),
|
||||
scopes=["repo", "user"],
|
||||
username="testuser",
|
||||
)
|
||||
|
||||
|
||||
def _make_user_password_cred(cred_id: str = "cred-789", provider: str = "openai"):
|
||||
return UserPasswordCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My Login",
|
||||
username=SecretStr("admin"),
|
||||
password=SecretStr("s3cret-pass"),
|
||||
)
|
||||
|
||||
|
||||
def _make_host_scoped_cred(cred_id: str = "cred-host", provider: str = "openai"):
|
||||
return HostScopedCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="Host Cred",
|
||||
host="https://api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer top-secret")},
|
||||
)
|
||||
|
||||
|
||||
def _make_sdk_default_cred(provider: str = "openai"):
|
||||
return APIKeyCredentials(
|
||||
id=f"{provider}-default",
|
||||
provider=provider,
|
||||
title=f"{provider} (default)",
|
||||
api_key=SecretStr("sk-platform-secret-key"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestGetCredentialReturnsMetaOnly:
|
||||
"""GET /{provider}/credentials/{cred_id} must not return secrets."""
|
||||
|
||||
def test_api_key_credential_no_secret(self):
|
||||
cred = _make_api_key_cred()
|
||||
with (
|
||||
patch.object(router, "dependencies", []),
|
||||
patch("backend.api.features.integrations.router.creds_manager") as mock_mgr,
|
||||
):
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-123")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-123"
|
||||
assert data["provider"] == "openai"
|
||||
assert data["type"] == "api_key"
|
||||
assert "api_key" not in data
|
||||
assert "sk-secret-key-value" not in str(data)
|
||||
|
||||
def test_oauth2_credential_no_secret(self):
|
||||
cred = _make_oauth2_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-456")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-456"
|
||||
assert data["scopes"] == ["repo", "user"]
|
||||
assert data["username"] == "testuser"
|
||||
assert "access_token" not in data
|
||||
assert "refresh_token" not in data
|
||||
assert "ghp_" not in str(data)
|
||||
|
||||
def test_user_password_credential_no_secret(self):
|
||||
cred = _make_user_password_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-789")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-789"
|
||||
assert "password" not in data
|
||||
assert "username" not in data or data["username"] is None
|
||||
assert "s3cret-pass" not in str(data)
|
||||
assert "admin" not in str(data)
|
||||
|
||||
def test_host_scoped_credential_no_secret(self):
|
||||
cred = _make_host_scoped_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-host")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-host"
|
||||
assert data["host"] == "https://api.example.com"
|
||||
assert "headers" not in data
|
||||
assert "top-secret" not in str(data)
|
||||
|
||||
def test_get_credential_wrong_provider_returns_404(self):
|
||||
"""Provider mismatch should return generic 404, not leak credential existence."""
|
||||
cred = _make_api_key_cred(provider="openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-123")
|
||||
|
||||
assert resp.status_code == 404
|
||||
assert resp.json()["detail"] == "Credentials not found"
|
||||
|
||||
def test_list_credentials_no_secrets(self):
|
||||
"""List endpoint must not leak secrets in any credential."""
|
||||
creds = [_make_api_key_cred(), _make_oauth2_cred()]
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=creds)
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
raw = str(resp.json())
|
||||
assert "sk-secret-key-value" not in raw
|
||||
assert "ghp_secret_token" not in raw
|
||||
assert "ghp_refresh_secret" not in raw
|
||||
|
||||
|
||||
class TestSdkDefaultCredentialsNotAccessible:
|
||||
"""SDK default credentials (ID ending in '-default') must be hidden."""
|
||||
|
||||
def test_get_sdk_default_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock()
|
||||
resp = client.get("/openai/credentials/openai-default")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.get.assert_not_called()
|
||||
|
||||
def test_list_credentials_excludes_sdk_defaults(self):
|
||||
user_cred = _make_api_key_cred()
|
||||
sdk_cred = _make_sdk_default_cred("openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=[user_cred, sdk_cred])
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
ids = [c["id"] for c in data]
|
||||
assert "cred-123" in ids
|
||||
assert "openai-default" not in ids
|
||||
|
||||
def test_list_by_provider_excludes_sdk_defaults(self):
|
||||
user_cred = _make_api_key_cred()
|
||||
sdk_cred = _make_sdk_default_cred("openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[user_cred, sdk_cred]
|
||||
)
|
||||
resp = client.get("/openai/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
ids = [c["id"] for c in data]
|
||||
assert "cred-123" in ids
|
||||
assert "openai-default" not in ids
|
||||
|
||||
def test_delete_sdk_default_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_id = AsyncMock()
|
||||
resp = client.request("DELETE", "/openai/credentials/openai-default")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.store.get_creds_by_id.assert_not_called()
|
||||
|
||||
|
||||
class TestCreateCredentialNoSecretInResponse:
|
||||
"""POST /{provider}/credentials must not return secrets."""
|
||||
|
||||
def test_create_api_key_no_secret_in_response(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.create = AsyncMock()
|
||||
resp = client.post(
|
||||
"/openai/credentials",
|
||||
json={
|
||||
"id": "new-cred",
|
||||
"provider": "openai",
|
||||
"type": "api_key",
|
||||
"title": "New Key",
|
||||
"api_key": "sk-newsecret",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["id"] == "new-cred"
|
||||
assert "api_key" not in data
|
||||
assert "sk-newsecret" not in str(data)
|
||||
|
||||
def test_create_with_sdk_default_id_rejected(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.create = AsyncMock()
|
||||
resp = client.post(
|
||||
"/openai/credentials",
|
||||
json={
|
||||
"id": "openai-default",
|
||||
"provider": "openai",
|
||||
"type": "api_key",
|
||||
"title": "Sneaky",
|
||||
"api_key": "sk-evil",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 403
|
||||
mock_mgr.create.assert_not_called()
|
||||
|
||||
|
||||
class TestManagedCredentials:
|
||||
"""AutoGPT-managed credentials cannot be deleted by users."""
|
||||
|
||||
def test_delete_is_managed_returns_403(self):
|
||||
cred = APIKeyCredentials(
|
||||
id="managed-cred-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-managed-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
|
||||
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
|
||||
|
||||
assert resp.status_code == 403
|
||||
assert "AutoGPT-managed" in resp.json()["detail"]
|
||||
|
||||
def test_list_credentials_includes_is_managed_field(self):
|
||||
managed = APIKeyCredentials(
|
||||
id="managed-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed)",
|
||||
api_key=SecretStr("sk-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
regular = APIKeyCredentials(
|
||||
id="regular-1",
|
||||
provider="openai",
|
||||
title="My Key",
|
||||
api_key=SecretStr("sk-key"),
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
managed_cred = next(c for c in data if c["id"] == "managed-1")
|
||||
regular_cred = next(c for c in data if c["id"] == "regular-1")
|
||||
assert managed_cred["is_managed"] is True
|
||||
assert regular_cred["is_managed"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Managed credential provisioning infrastructure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_managed_cred(
|
||||
provider: str = "agent_mail", pod_id: str = "pod-abc"
|
||||
) -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="managed-auto",
|
||||
provider=provider,
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-pod-key"),
|
||||
is_managed=True,
|
||||
metadata={"pod_id": pod_id},
|
||||
)
|
||||
|
||||
|
||||
def _make_store_mock(**kwargs) -> MagicMock:
|
||||
"""Create a store mock with a working async ``locks()`` context manager."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_locked(key):
|
||||
yield
|
||||
|
||||
locks_obj = MagicMock()
|
||||
locks_obj.locked = _noop_locked
|
||||
|
||||
store = MagicMock(**kwargs)
|
||||
store.locks = AsyncMock(return_value=locks_obj)
|
||||
return store
|
||||
|
||||
|
||||
class TestEnsureManagedCredentials:
|
||||
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provisions_when_missing(self):
|
||||
"""Provider.provision() is called when no managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(return_value=cred)
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
store.add_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_awaited_once_with("user-1")
|
||||
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_already_exists(self):
|
||||
"""Provider.provision() is NOT called when managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=True)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_unavailable(self):
|
||||
"""Provider.provision() is NOT called when provider is not available."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=False)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
store.has_managed_credential.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provision_failure_does_not_propagate(self):
|
||||
"""A failed provision is logged but does not raise."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
# No exception raised — provisioning failure is swallowed.
|
||||
|
||||
|
||||
class TestCleanupManagedCredentials:
|
||||
"""Unit tests for cleanup_managed_credentials."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_deprovision_for_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_non_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
regular = _make_api_key_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "openai"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[regular])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["openai"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprovision_failure_does_not_propagate(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
# No exception raised — cleanup failure is swallowed.
|
||||
@@ -0,0 +1,120 @@
|
||||
"""Shared logic for adding store agents to a user's library.
|
||||
|
||||
Both `add_store_agent_to_library` and `add_store_agent_to_library_as_admin`
|
||||
delegate to these helpers so the duplication-prone create/restore/dedup
|
||||
logic lives in exactly one place.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.data.graph as graph_db
|
||||
from backend.data.graph import GraphModel, GraphSettings
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def resolve_graph_for_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
admin: bool,
|
||||
) -> GraphModel:
|
||||
"""Look up a StoreListingVersion and resolve its graph.
|
||||
|
||||
When ``admin=True``, uses ``get_graph_as_admin`` to bypass the marketplace
|
||||
APPROVED-only check. Otherwise uses the regular ``get_graph``.
|
||||
"""
|
||||
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
)
|
||||
if not slv or not slv.AgentGraph:
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
ag = slv.AgentGraph
|
||||
if admin:
|
||||
graph_model = await graph_db.get_graph_as_admin(
|
||||
graph_id=ag.id, version=ag.version, user_id=user_id
|
||||
)
|
||||
else:
|
||||
graph_model = await graph_db.get_graph(
|
||||
graph_id=ag.id, version=ag.version, user_id=user_id
|
||||
)
|
||||
|
||||
if not graph_model:
|
||||
raise NotFoundError(f"Graph #{ag.id} v{ag.version} not found or accessible")
|
||||
return graph_model
|
||||
|
||||
|
||||
async def add_graph_to_library(
|
||||
store_listing_version_id: str,
|
||||
graph_model: GraphModel,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Check existing / restore soft-deleted / create new LibraryAgent.
|
||||
|
||||
Uses a create-then-catch-UniqueViolationError-then-update pattern on
|
||||
the (userId, agentGraphId, agentGraphVersion) composite unique constraint.
|
||||
This is more robust than ``upsert`` because Prisma's upsert atomicity
|
||||
guarantees are not well-documented for all versions.
|
||||
"""
|
||||
settings_json = SafeJson(GraphSettings.from_graph(graph_model).model_dump())
|
||||
_include = library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
)
|
||||
|
||||
try:
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_model.id,
|
||||
"version": graph_model.version,
|
||||
}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": settings_json,
|
||||
},
|
||||
include=_include,
|
||||
)
|
||||
except prisma.errors.UniqueViolationError:
|
||||
# Already exists — update to restore if previously soft-deleted/archived
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().update(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_model.id,
|
||||
"agentGraphVersion": graph_model.version,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"settings": settings_json,
|
||||
},
|
||||
include=_include,
|
||||
)
|
||||
if added_agent is None:
|
||||
raise NotFoundError(
|
||||
f"LibraryAgent for graph #{graph_model.id} "
|
||||
f"v{graph_model.version} not found after UniqueViolationError"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Added graph #{graph_model.id} v{graph_model.version} "
|
||||
f"for store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
@@ -0,0 +1,80 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma.errors
|
||||
import pytest
|
||||
|
||||
from ._add_to_library import add_graph_to_library
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_create_new_agent() -> None:
|
||||
"""When no matching LibraryAgent exists, create inserts a new one."""
|
||||
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
|
||||
created_agent = MagicMock(name="CreatedLibraryAgent")
|
||||
converted_agent = MagicMock(name="ConvertedLibraryAgent")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(created_agent)
|
||||
# Verify create was called with correct data
|
||||
create_call = mock_prisma.return_value.create.call_args
|
||||
create_data = create_call.kwargs["data"]
|
||||
assert create_data["User"] == {"connect": {"id": "user-id"}}
|
||||
assert create_data["AgentGraph"] == {
|
||||
"connect": {"graphVersionId": {"id": "graph-id", "version": 2}}
|
||||
}
|
||||
assert create_data["isCreatedByUser"] is False
|
||||
assert create_data["useGraphIsActiveVersion"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
|
||||
"""UniqueViolationError on create falls back to update."""
|
||||
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
|
||||
updated_agent = MagicMock(name="UpdatedLibraryAgent")
|
||||
converted_agent = MagicMock(name="ConvertedLibraryAgent")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(
|
||||
side_effect=prisma.errors.UniqueViolationError(
|
||||
MagicMock(), message="unique constraint"
|
||||
)
|
||||
)
|
||||
mock_prisma.return_value.update = AsyncMock(return_value=updated_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(updated_agent)
|
||||
# Verify update was called with correct where and data
|
||||
update_call = mock_prisma.return_value.update.call_args
|
||||
assert update_call.kwargs["where"] == {
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "user-id",
|
||||
"agentGraphId": "graph-id",
|
||||
"agentGraphVersion": 2,
|
||||
}
|
||||
}
|
||||
update_data = update_call.kwargs["data"]
|
||||
assert update_data["isDeleted"] is False
|
||||
assert update_data["isArchived"] is False
|
||||
@@ -336,12 +336,15 @@ async def get_library_agent_by_graph_id(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_version: Optional[int] = None,
|
||||
include_archived: bool = False,
|
||||
) -> library_model.LibraryAgent | None:
|
||||
filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"agentGraphId": graph_id,
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
if not include_archived:
|
||||
filter["isArchived"] = False
|
||||
if graph_version is not None:
|
||||
filter["agentGraphVersion"] = graph_version
|
||||
|
||||
@@ -433,32 +436,53 @@ async def create_library_agent(
|
||||
async with transaction() as tx:
|
||||
library_agents = await asyncio.gather(
|
||||
*(
|
||||
prisma.models.LibraryAgent.prisma(tx).create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_entry.id,
|
||||
"version": graph_entry.version,
|
||||
prisma.models.LibraryAgent.prisma(tx).upsert(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_entry.id,
|
||||
"agentGraphVersion": graph_entry.version,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"create": prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == graph.user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_entry.id,
|
||||
"version": graph_entry.version,
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
settings=SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
"update": {
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"useGraphIsActiveVersion": True,
|
||||
"settings": SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
},
|
||||
settings=SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
@@ -582,7 +606,9 @@ async def update_graph_in_library(
|
||||
|
||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||
|
||||
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
||||
library_agent = await get_library_agent_by_graph_id(
|
||||
user_id, created_graph.id, include_archived=True
|
||||
)
|
||||
if not library_agent:
|
||||
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
||||
|
||||
@@ -818,92 +844,38 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
|
||||
async def add_store_agent_to_library(
|
||||
store_listing_version_id: str, user_id: str
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Adds a marketplace agent to the user’s library.
|
||||
|
||||
See also: `add_store_agent_to_library_as_admin()` which uses
|
||||
`get_graph_as_admin` to bypass marketplace status checks for admin review.
|
||||
"""
|
||||
Adds an agent from a store listing version to the user's library if they don't already have it.
|
||||
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
|
||||
|
||||
Args:
|
||||
store_listing_version_id: The ID of the store listing version containing the agent.
|
||||
user_id: The user’s library to which the agent is being added.
|
||||
|
||||
Returns:
|
||||
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the store listing or associated agent is not found.
|
||||
DatabaseError: If there's an issue creating the LibraryAgent record.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Adding agent from store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
)
|
||||
graph_model = await resolve_graph_for_library(
|
||||
store_listing_version_id, user_id, admin=False
|
||||
)
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
|
||||
|
||||
graph = store_listing_version.AgentGraph
|
||||
|
||||
# Convert to GraphModel to check for HITL blocks
|
||||
graph_model = await graph_db.get_graph(
|
||||
graph_id=graph.id,
|
||||
version=graph.version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=False,
|
||||
async def add_store_agent_to_library_as_admin(
|
||||
store_listing_version_id: str, user_id: str
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Admin variant that uses `get_graph_as_admin` to bypass marketplace
|
||||
APPROVED-only checks, allowing admins to add pending agents for review."""
|
||||
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
|
||||
|
||||
logger.warning(
|
||||
f"ADMIN adding agent from store listing version "
|
||||
f"#{store_listing_version_id} to library for user #{user_id}"
|
||||
)
|
||||
if not graph_model:
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph.id} v{graph.version} not found or accessible"
|
||||
)
|
||||
|
||||
# Check if user already has this agent (non-deleted)
|
||||
if existing := await get_library_agent_by_graph_id(
|
||||
user_id, graph.id, graph.version
|
||||
):
|
||||
return existing
|
||||
|
||||
# Check for soft-deleted version and restore it
|
||||
deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
}
|
||||
},
|
||||
graph_model = await resolve_graph_for_library(
|
||||
store_listing_version_id, user_id, admin=True
|
||||
)
|
||||
if deleted_agent and deleted_agent.isDeleted:
|
||||
return await update_library_agent(deleted_agent.id, user_id, is_deleted=False)
|
||||
|
||||
# Create LibraryAgent entry
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": SafeJson(GraphSettings.from_graph(graph_model).model_dump()),
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
logger.debug(
|
||||
f"Added graph #{graph.id} v{graph.version}"
|
||||
f"for store listing version #{store_listing_version.id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
@@ -85,10 +87,6 @@ async def test_get_library_agents(mocker):
|
||||
async def test_add_agent_to_library(mocker):
|
||||
await connect()
|
||||
|
||||
# Mock the transaction context
|
||||
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
|
||||
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
|
||||
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
|
||||
# Mock data
|
||||
mock_store_listing_data = prisma.models.StoreListingVersion(
|
||||
id="version123",
|
||||
@@ -143,15 +141,18 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
)
|
||||
|
||||
# Mock graph_db.get_graph function that's called to check for HITL blocks
|
||||
mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
|
||||
# Mock graph_db.get_graph function that's called in resolve_graph_for_library
|
||||
# (lives in _add_to_library.py after refactor, not db.py)
|
||||
mock_graph_db = mocker.patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
)
|
||||
mock_graph_model = mocker.Mock()
|
||||
mock_graph_model.id = "agent1"
|
||||
mock_graph_model.version = 1
|
||||
mock_graph_model.nodes = (
|
||||
[]
|
||||
) # Empty list so _has_human_in_the_loop_blocks returns False
|
||||
@@ -170,37 +171,27 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
mock_library_agent.return_value.find_unique.assert_called_once_with(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Check that create was called with the expected data including settings
|
||||
create_call_args = mock_library_agent.return_value.create.call_args
|
||||
assert create_call_args is not None
|
||||
|
||||
# Verify the main structure
|
||||
expected_data = {
|
||||
# Verify the create data structure
|
||||
create_data = create_call_args.kwargs["data"]
|
||||
expected_create = {
|
||||
"User": {"connect": {"id": "test-user"}},
|
||||
"AgentGraph": {"connect": {"graphVersionId": {"id": "agent1", "version": 1}}},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
}
|
||||
|
||||
actual_data = create_call_args[1]["data"]
|
||||
# Check that all expected fields are present
|
||||
for key, value in expected_data.items():
|
||||
assert actual_data[key] == value
|
||||
for key, value in expected_create.items():
|
||||
assert create_data[key] == value
|
||||
|
||||
# Check that settings field is present and is a SafeJson object
|
||||
assert "settings" in actual_data
|
||||
assert hasattr(actual_data["settings"], "__class__") # Should be a SafeJson object
|
||||
assert "settings" in create_data
|
||||
assert hasattr(create_data["settings"], "__class__") # Should be a SafeJson object
|
||||
|
||||
# Check include parameter
|
||||
assert create_call_args[1]["include"] == library_agent_include(
|
||||
assert create_call_args.kwargs["include"] == library_agent_include(
|
||||
"test-user", include_nodes=False, include_executions=False
|
||||
)
|
||||
|
||||
@@ -224,3 +215,141 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agent_by_graph_id_excludes_archived(mocker):
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
|
||||
result = await db.get_library_agent_by_graph_id("test-user", "agent1", 7)
|
||||
|
||||
assert result is None
|
||||
mock_library_agent.return_value.find_first.assert_called_once()
|
||||
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
|
||||
assert where == {
|
||||
"agentGraphId": "agent1",
|
||||
"userId": "test-user",
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"agentGraphVersion": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agent_by_graph_id_can_include_archived(mocker):
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
|
||||
result = await db.get_library_agent_by_graph_id(
|
||||
"test-user",
|
||||
"agent1",
|
||||
7,
|
||||
include_archived=True,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
mock_library_agent.return_value.find_first.assert_called_once()
|
||||
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
|
||||
assert where == {
|
||||
"agentGraphId": "agent1",
|
||||
"userId": "test-user",
|
||||
"isDeleted": False,
|
||||
"agentGraphVersion": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_graph_in_library_allows_archived_library_agent(mocker):
|
||||
graph = mocker.Mock(id="graph-id")
|
||||
existing_version = mocker.Mock(version=1, is_active=True)
|
||||
graph_model = mocker.Mock()
|
||||
created_graph = mocker.Mock(id="graph-id", version=2, is_active=False)
|
||||
current_library_agent = mocker.Mock()
|
||||
updated_library_agent = mocker.Mock()
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.get_graph_all_versions",
|
||||
new=mocker.AsyncMock(return_value=[existing_version]),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.make_graph_model",
|
||||
return_value=graph_model,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.create_graph",
|
||||
new=mocker.AsyncMock(return_value=created_graph),
|
||||
)
|
||||
mock_get_library_agent = mocker.patch(
|
||||
"backend.api.features.library.db.get_library_agent_by_graph_id",
|
||||
new=mocker.AsyncMock(return_value=current_library_agent),
|
||||
)
|
||||
mock_update_library_agent = mocker.patch(
|
||||
"backend.api.features.library.db.update_library_agent_version_and_settings",
|
||||
new=mocker.AsyncMock(return_value=updated_library_agent),
|
||||
)
|
||||
|
||||
result_graph, result_library_agent = await db.update_graph_in_library(
|
||||
graph,
|
||||
"test-user",
|
||||
)
|
||||
|
||||
assert result_graph is created_graph
|
||||
assert result_library_agent is updated_library_agent
|
||||
assert graph.version == 2
|
||||
graph_model.reassign_ids.assert_called_once_with(
|
||||
user_id="test-user", reassign_graph_id=False
|
||||
)
|
||||
mock_get_library_agent.assert_awaited_once_with(
|
||||
"test-user",
|
||||
"graph-id",
|
||||
include_archived=True,
|
||||
)
|
||||
mock_update_library_agent.assert_awaited_once_with("test-user", created_graph)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_library_agent_uses_upsert():
|
||||
"""create_library_agent should use upsert (not create) to handle duplicates."""
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.id = "graph-1"
|
||||
mock_graph.version = 1
|
||||
mock_graph.user_id = "user-1"
|
||||
mock_graph.nodes = []
|
||||
mock_graph.sub_graphs = []
|
||||
|
||||
mock_upserted = MagicMock(name="UpsertedLibraryAgent")
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_tx():
|
||||
yield None
|
||||
|
||||
with (
|
||||
patch("backend.api.features.library.db.transaction", fake_tx),
|
||||
patch("prisma.models.LibraryAgent.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library.db.add_generated_agent_image",
|
||||
new=AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.upsert = AsyncMock(return_value=mock_upserted)
|
||||
|
||||
result = await db.create_library_agent(mock_graph, "user-1")
|
||||
|
||||
assert len(result) == 1
|
||||
upsert_call = mock_prisma.return_value.upsert.call_args
|
||||
assert upsert_call is not None
|
||||
# Verify the upsert where clause uses the composite unique key
|
||||
where = upsert_call.kwargs["where"]
|
||||
assert "userId_agentGraphId_agentGraphVersion" in where
|
||||
# Verify the upsert data has both create and update branches
|
||||
data = upsert_call.kwargs["data"]
|
||||
assert "create" in data
|
||||
assert "update" in data
|
||||
# Verify update branch restores soft-deleted/archived agents
|
||||
assert data["update"]["isDeleted"] is False
|
||||
assert data["update"]["isArchived"] is False
|
||||
|
||||
@@ -12,6 +12,7 @@ Tests cover:
|
||||
5. Complete OAuth flow end-to-end
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
@@ -58,14 +59,27 @@ async def test_user(server, test_user_id: str):
|
||||
|
||||
yield test_user_id
|
||||
|
||||
# Cleanup - delete in correct order due to foreign key constraints
|
||||
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
)
|
||||
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
|
||||
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
||||
# Cleanup - delete in correct order due to foreign key constraints.
|
||||
# Wrap in try/except because the event loop or Prisma engine may already
|
||||
# be closed during session teardown on Python 3.12+.
|
||||
try:
|
||||
await asyncio.gather(
|
||||
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
|
||||
PrismaOAuthRefreshToken.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
)
|
||||
await asyncio.gather(
|
||||
PrismaOAuthApplication.prisma().delete_many(
|
||||
where={"ownerId": test_user_id}
|
||||
),
|
||||
PrismaUser.prisma().delete(where={"id": test_user_id}),
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
||||
@@ -391,6 +391,11 @@ async def get_available_graph(
|
||||
async def get_store_agent_by_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get agent details from the StoreAgent view (APPROVED agents only).
|
||||
|
||||
See also: `get_store_agent_details_as_admin()` which bypasses the
|
||||
APPROVED-only StoreAgent view for admin preview of pending submissions.
|
||||
"""
|
||||
logger.debug(f"Getting store agent details for {store_listing_version_id}")
|
||||
|
||||
try:
|
||||
@@ -411,6 +416,57 @@ async def get_store_agent_by_version_id(
|
||||
raise DatabaseError("Failed to fetch agent details") from e
|
||||
|
||||
|
||||
async def get_store_agent_details_as_admin(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get agent details for admin preview, bypassing the APPROVED-only
|
||||
StoreAgent view. Queries StoreListingVersion directly so pending
|
||||
submissions are visible."""
|
||||
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id},
|
||||
include={
|
||||
"StoreListing": {"include": {"CreatorProfile": True}},
|
||||
},
|
||||
)
|
||||
if not slv or not slv.StoreListing:
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found"
|
||||
)
|
||||
|
||||
listing = slv.StoreListing
|
||||
# CreatorProfile is a required FK relation — should always exist.
|
||||
# If it's None, the DB is in a bad state.
|
||||
profile = listing.CreatorProfile
|
||||
if not profile:
|
||||
raise DatabaseError(
|
||||
f"StoreListing {listing.id} has no CreatorProfile — FK violated"
|
||||
)
|
||||
|
||||
return store_model.StoreAgentDetails(
|
||||
store_listing_version_id=slv.id,
|
||||
slug=listing.slug,
|
||||
agent_name=slv.name,
|
||||
agent_video=slv.videoUrl or "",
|
||||
agent_output_demo=slv.agentOutputDemoUrl or "",
|
||||
agent_image=slv.imageUrls,
|
||||
creator=profile.username,
|
||||
creator_avatar=profile.avatarUrl or "",
|
||||
sub_heading=slv.subHeading,
|
||||
description=slv.description,
|
||||
instructions=slv.instructions,
|
||||
categories=slv.categories,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
versions=[str(slv.version)],
|
||||
graph_id=slv.agentGraphId,
|
||||
graph_versions=[str(slv.agentGraphVersion)],
|
||||
last_updated=slv.updatedAt,
|
||||
recommended_schedule_cron=slv.recommendedScheduleCron,
|
||||
active_version_id=listing.activeVersionId or slv.id,
|
||||
has_approved_version=listing.hasApprovedVersion,
|
||||
)
|
||||
|
||||
|
||||
class StoreCreatorsSortOptions(Enum):
|
||||
# NOTE: values correspond 1:1 to columns of the Creator view
|
||||
AGENT_RATING = "agent_rating"
|
||||
|
||||
@@ -980,14 +980,16 @@ async def execute_graph(
|
||||
source: Annotated[GraphExecutionSource | None, Body(embed=True)] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
dry_run: Annotated[bool, Body(embed=True)] = False,
|
||||
) -> execution_db.GraphExecutionMeta:
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
if not dry_run:
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await execution_utils.add_graph_execution(
|
||||
@@ -997,6 +999,7 @@ async def execute_graph(
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
# Record successful graph execution
|
||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||
|
||||
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.rate_limit_admin_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
@@ -117,6 +118,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Register managed credential providers (e.g. AgentMail)
|
||||
from backend.integrations.managed_providers import register_all
|
||||
|
||||
register_all()
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
@@ -318,6 +324,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.rate_limit_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/copilot",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
@@ -528,8 +539,11 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
user_id: str,
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
from .features.integrations.router import create_credentials, get_credential
|
||||
):
|
||||
from backend.api.features.integrations.router import (
|
||||
create_credentials,
|
||||
get_credential,
|
||||
)
|
||||
|
||||
try:
|
||||
return await create_credentials(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -19,6 +20,33 @@ from backend.blocks.llm import (
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||
|
||||
# Minimum max_output_tokens accepted by OpenAI-compatible APIs.
|
||||
# A true/false answer fits comfortably within this budget.
|
||||
MIN_LLM_OUTPUT_TOKENS = 16
|
||||
|
||||
|
||||
def _parse_boolean_response(response_text: str) -> tuple[bool, str | None]:
|
||||
"""Parse an LLM response into a boolean result.
|
||||
|
||||
Returns a ``(result, error)`` tuple. *error* is ``None`` when the
|
||||
response is unambiguous; otherwise it contains a diagnostic message
|
||||
and *result* defaults to ``False``.
|
||||
"""
|
||||
text = response_text.strip().lower()
|
||||
if text == "true":
|
||||
return True, None
|
||||
if text == "false":
|
||||
return False, None
|
||||
|
||||
# Fuzzy match – use word boundaries to avoid false positives like "untrue".
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", text))
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
return True, None
|
||||
if tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
return False, None
|
||||
|
||||
return False, f"Unclear AI response: '{response_text}'"
|
||||
|
||||
|
||||
class AIConditionBlock(AIBlockBase):
|
||||
"""
|
||||
@@ -162,54 +190,26 @@ class AIConditionBlock(AIBlockBase):
|
||||
]
|
||||
|
||||
# Call the LLM
|
||||
try:
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=10, # We only expect a true/false response
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=MIN_LLM_OUTPUT_TOKENS,
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
result, error = _parse_boolean_response(response.response)
|
||||
if error:
|
||||
yield "error", error
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
response_text = response.response.strip().lower()
|
||||
if response_text == "true":
|
||||
result = True
|
||||
elif response_text == "false":
|
||||
result = False
|
||||
else:
|
||||
# If the response is not clear, try to interpret it using word boundaries
|
||||
import re
|
||||
|
||||
# Use word boundaries to avoid false positives like 'untrue' or '10'
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
|
||||
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
result = True
|
||||
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
result = False
|
||||
else:
|
||||
# Unclear or conflicting response - default to False and yield error
|
||||
result = False
|
||||
yield "error", f"Unclear AI response: '{response.response}'"
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
except Exception as e:
|
||||
# In case of any error, default to False to be safe
|
||||
result = False
|
||||
# Log the error but don't fail the block execution
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"AI condition evaluation failed: {str(e)}")
|
||||
yield "error", f"AI evaluation failed: {str(e)}"
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
# Yield results
|
||||
yield "result", result
|
||||
|
||||
147
autogpt_platform/backend/backend/blocks/ai_condition_test.py
Normal file
147
autogpt_platform/backend/backend/blocks/ai_condition_test.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Tests for AIConditionBlock – regression coverage for max_tokens and error propagation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.ai_condition import (
|
||||
MIN_LLM_OUTPUT_TOKENS,
|
||||
AIConditionBlock,
|
||||
_parse_boolean_response,
|
||||
)
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AICredentials,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
_TEST_AI_CREDENTIALS = cast(AICredentials, TEST_CREDENTIALS_INPUT)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to collect all yields from the async generator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _collect_outputs(block: AIConditionBlock, input_data, credentials):
|
||||
outputs: dict[str, object] = {}
|
||||
async for name, value in block.run(input_data, credentials=credentials):
|
||||
outputs[name] = value
|
||||
return outputs
|
||||
|
||||
|
||||
def _make_input(**overrides) -> AIConditionBlock.Input:
|
||||
defaults: dict = {
|
||||
"input_value": "hello@example.com",
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "yes!",
|
||||
"no_value": "no!",
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return AIConditionBlock.Input(**defaults)
|
||||
|
||||
|
||||
def _mock_llm_response(response_text: str) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response=response_text,
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
reasoning=None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_boolean_response unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseBooleanResponse:
|
||||
def test_true_exact(self):
|
||||
assert _parse_boolean_response("true") == (True, None)
|
||||
|
||||
def test_false_exact(self):
|
||||
assert _parse_boolean_response("false") == (False, None)
|
||||
|
||||
def test_true_with_whitespace(self):
|
||||
assert _parse_boolean_response(" True ") == (True, None)
|
||||
|
||||
def test_yes_fuzzy(self):
|
||||
assert _parse_boolean_response("Yes") == (True, None)
|
||||
|
||||
def test_no_fuzzy(self):
|
||||
assert _parse_boolean_response("no") == (False, None)
|
||||
|
||||
def test_one_fuzzy(self):
|
||||
assert _parse_boolean_response("1") == (True, None)
|
||||
|
||||
def test_zero_fuzzy(self):
|
||||
assert _parse_boolean_response("0") == (False, None)
|
||||
|
||||
def test_unclear_response(self):
|
||||
result, error = _parse_boolean_response("I'm not sure")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
assert "Unclear" in error
|
||||
|
||||
def test_conflicting_tokens(self):
|
||||
result, error = _parse_boolean_response("true and false")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: max_tokens is set to MIN_LLM_OUTPUT_TOKENS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaxTokensRegression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_receives_min_output_tokens(self):
|
||||
"""max_tokens must be MIN_LLM_OUTPUT_TOKENS (16) – the previous value
|
||||
of 1 was too low and caused OpenAI to reject the request."""
|
||||
block = AIConditionBlock()
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def spy_llm_call(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return _mock_llm_response("true")
|
||||
|
||||
block.llm_call = spy_llm_call # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
|
||||
assert captured_kwargs["max_tokens"] == MIN_LLM_OUTPUT_TOKENS
|
||||
assert captured_kwargs["max_tokens"] == 16
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: exceptions from llm_call must propagate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExceptionPropagation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_exception_propagates(self):
|
||||
"""If llm_call raises, the exception must NOT be swallowed.
|
||||
Previously the block caught all exceptions and silently returned
|
||||
result=False."""
|
||||
block = AIConditionBlock()
|
||||
|
||||
async def boom(**kwargs):
|
||||
raise RuntimeError("LLM provider error")
|
||||
|
||||
block.llm_call = boom # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
with pytest.raises(RuntimeError, match="LLM provider error"):
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
@@ -15,6 +15,12 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.copilot.permissions import (
|
||||
CopilotPermissions,
|
||||
ToolName,
|
||||
all_known_tool_names,
|
||||
validate_block_identifiers,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -96,6 +102,50 @@ class AutoPilotBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools: list[ToolName] = SchemaField(
|
||||
description=(
|
||||
"Tool names to filter. Works with tools_exclude to form an "
|
||||
"allow-list or deny-list. "
|
||||
"Leave empty to apply no tool filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'tools' list is interpreted. "
|
||||
"True (default): 'tools' is a deny-list — listed tools are blocked, "
|
||||
"all others are allowed. An empty 'tools' list means allow everything. "
|
||||
"False: 'tools' is an allow-list — only listed tools are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks: list[str] = SchemaField(
|
||||
description=(
|
||||
"Block identifiers to filter when the copilot uses run_block. "
|
||||
"Each entry can be: a block name (e.g. 'HTTP Request'), "
|
||||
"a full block UUID, or the first 8 hex characters of the UUID "
|
||||
"(e.g. 'c069dc6b'). Works with blocks_exclude. "
|
||||
"Leave empty to apply no block filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'blocks' list is interpreted. "
|
||||
"True (default): 'blocks' is a deny-list — listed blocks are blocked, "
|
||||
"all others are allowed. An empty 'blocks' list means allow everything. "
|
||||
"False: 'blocks' is an allow-list — only listed blocks are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# timeout_seconds removed: the SDK manages its own heartbeat-based
|
||||
# timeouts internally; wrapping with asyncio.timeout corrupts the
|
||||
# SDK's internal stream (see service.py CRITICAL comment).
|
||||
@@ -184,7 +234,7 @@ class AutoPilotBlock(Block):
|
||||
|
||||
async def create_session(self, user_id: str) -> str:
|
||||
"""Create a new chat session and return its ID (mockable for tests)."""
|
||||
from backend.copilot.model import create_chat_session
|
||||
from backend.copilot.model import create_chat_session # avoid circular import
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
return session.session_id
|
||||
@@ -196,6 +246,7 @@ class AutoPilotBlock(Block):
|
||||
session_id: str,
|
||||
max_recursion_depth: int,
|
||||
user_id: str,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
|
||||
"""Invoke the copilot and collect all stream results.
|
||||
|
||||
@@ -209,14 +260,21 @@ class AutoPilotBlock(Block):
|
||||
session_id: Chat session to use.
|
||||
max_recursion_depth: Maximum allowed recursion nesting.
|
||||
user_id: Authenticated user ID.
|
||||
permissions: Optional capability filter restricting tools/blocks.
|
||||
|
||||
Returns:
|
||||
A tuple of (response_text, tool_calls, history_json, session_id, usage).
|
||||
"""
|
||||
from backend.copilot.sdk.collect import collect_copilot_response
|
||||
from backend.copilot.sdk.collect import (
|
||||
collect_copilot_response, # avoid circular import
|
||||
)
|
||||
|
||||
tokens = _check_recursion(max_recursion_depth)
|
||||
perm_token = None
|
||||
try:
|
||||
effective_permissions, perm_token = _merge_inherited_permissions(
|
||||
permissions
|
||||
)
|
||||
effective_prompt = prompt
|
||||
if system_context:
|
||||
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
|
||||
@@ -225,6 +283,7 @@ class AutoPilotBlock(Block):
|
||||
session_id=session_id,
|
||||
message=effective_prompt,
|
||||
user_id=user_id,
|
||||
permissions=effective_permissions,
|
||||
)
|
||||
|
||||
# Build a lightweight conversation summary from streamed data.
|
||||
@@ -271,6 +330,8 @@ class AutoPilotBlock(Block):
|
||||
)
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
if perm_token is not None:
|
||||
_inherited_permissions.reset(perm_token)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -295,6 +356,13 @@ class AutoPilotBlock(Block):
|
||||
yield "error", "max_recursion_depth must be at least 1."
|
||||
return
|
||||
|
||||
# Validate and build permissions eagerly — fail before creating a session.
|
||||
permissions = await _build_and_validate_permissions(input_data)
|
||||
if isinstance(permissions, str):
|
||||
# Validation error returned as a string message.
|
||||
yield "error", permissions
|
||||
return
|
||||
|
||||
# Create session eagerly so the user always gets the session_id,
|
||||
# even if the downstream stream fails (avoids orphaned sessions).
|
||||
sid = input_data.session_id
|
||||
@@ -312,6 +380,7 @@ class AutoPilotBlock(Block):
|
||||
session_id=sid,
|
||||
max_recursion_depth=input_data.max_recursion_depth,
|
||||
user_id=execution_context.user_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
@@ -374,3 +443,78 @@ def _reset_recursion(
|
||||
"""Restore recursion depth and limit to their previous values."""
|
||||
_autopilot_recursion_depth.reset(tokens[0])
|
||||
_autopilot_recursion_limit.reset(tokens[1])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Permission helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Inherited permissions from a parent AutoPilotBlock execution.
|
||||
# This acts as a ceiling: child executions can only be more restrictive.
|
||||
_inherited_permissions: contextvars.ContextVar["CopilotPermissions | None"] = (
|
||||
contextvars.ContextVar("_inherited_permissions", default=None)
|
||||
)
|
||||
|
||||
|
||||
async def _build_and_validate_permissions(
|
||||
input_data: "AutoPilotBlock.Input",
|
||||
) -> "CopilotPermissions | str":
|
||||
"""Build a :class:`CopilotPermissions` from block input and validate it.
|
||||
|
||||
Returns a :class:`CopilotPermissions` on success or a human-readable
|
||||
error string if validation fails.
|
||||
"""
|
||||
# Tool names are validated by Pydantic via the ToolName Literal type
|
||||
# at model construction time — no runtime check needed here.
|
||||
# Validate block identifiers against live block registry.
|
||||
if input_data.blocks:
|
||||
invalid_blocks = await validate_block_identifiers(input_data.blocks)
|
||||
if invalid_blocks:
|
||||
return (
|
||||
f"Unknown block identifier(s) in 'blocks': {invalid_blocks}. "
|
||||
"Use find_block to discover valid block names and IDs. "
|
||||
"You may also use the first 8 characters of a block UUID."
|
||||
)
|
||||
|
||||
return CopilotPermissions(
|
||||
tools=list(input_data.tools),
|
||||
tools_exclude=input_data.tools_exclude,
|
||||
blocks=input_data.blocks,
|
||||
blocks_exclude=input_data.blocks_exclude,
|
||||
)
|
||||
|
||||
|
||||
def _merge_inherited_permissions(
|
||||
permissions: "CopilotPermissions | None",
|
||||
) -> "tuple[CopilotPermissions | None, contextvars.Token[CopilotPermissions | None] | None]":
|
||||
"""Merge *permissions* with any inherited parent permissions.
|
||||
|
||||
The merged result is stored back into the contextvar so that any nested
|
||||
AutoPilotBlock invocation (sub-agent) inherits the merged ceiling.
|
||||
|
||||
Returns a tuple of (merged_permissions, reset_token). The caller MUST
|
||||
reset the contextvar via ``_inherited_permissions.reset(token)`` in a
|
||||
``finally`` block when ``reset_token`` is not None — this prevents
|
||||
permission leakage between sequential independent executions in the same
|
||||
asyncio task.
|
||||
"""
|
||||
parent = _inherited_permissions.get()
|
||||
|
||||
if permissions is None and parent is None:
|
||||
return None, None
|
||||
|
||||
all_tools = all_known_tool_names()
|
||||
|
||||
if permissions is None:
|
||||
permissions = CopilotPermissions() # allow-all; will be narrowed by parent
|
||||
|
||||
merged = (
|
||||
permissions.merged_with_parent(parent, all_tools)
|
||||
if parent is not None
|
||||
else permissions
|
||||
)
|
||||
|
||||
# Store merged permissions as the new inherited ceiling for nested calls.
|
||||
# Return the token so the caller can restore the previous value in finally.
|
||||
token = _inherited_permissions.set(merged)
|
||||
return merged, token
|
||||
|
||||
@@ -0,0 +1,265 @@
|
||||
"""Tests for AutoPilotBlock permission fields and validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AutoPilotBlock,
|
||||
_build_and_validate_permissions,
|
||||
_inherited_permissions,
|
||||
_merge_inherited_permissions,
|
||||
)
|
||||
from backend.copilot.permissions import CopilotPermissions, all_known_tool_names
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input(**kwargs) -> AutoPilotBlock.Input:
|
||||
defaults = {
|
||||
"prompt": "Do something",
|
||||
"system_context": "",
|
||||
"session_id": "",
|
||||
"max_recursion_depth": 3,
|
||||
"tools": [],
|
||||
"tools_exclude": True,
|
||||
"blocks": [],
|
||||
"blocks_exclude": True,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return AutoPilotBlock.Input(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_and_validate_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestBuildAndValidatePermissions:
|
||||
async def test_empty_inputs_returns_empty_permissions(self):
|
||||
inp = _make_input()
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.is_empty()
|
||||
|
||||
async def test_valid_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["run_block", "web_fetch"], tools_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.tools == ["run_block", "web_fetch"]
|
||||
assert result.tools_exclude is True
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are now caught at Pydantic validation time
|
||||
(Literal type), before ``_build_and_validate_permissions`` is called."""
|
||||
with pytest.raises(ValidationError, match="not_a_real_tool"):
|
||||
_make_input(tools=["not_a_real_tool"])
|
||||
|
||||
async def test_valid_block_name_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.blocks == ["HTTP Request"]
|
||||
|
||||
async def test_valid_partial_uuid_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["c069dc6b"], blocks_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
|
||||
async def test_invalid_block_identifier_returns_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["totally_fake_block"])
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, str)
|
||||
assert "totally_fake_block" in result
|
||||
assert "Unknown block identifier" in result
|
||||
|
||||
async def test_sdk_builtin_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["Read", "Task", "WebSearch"], tools_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert not result.tools_exclude
|
||||
|
||||
async def test_empty_blocks_skips_validation(self):
|
||||
# Should not call validate_block_identifiers at all when blocks=[].
|
||||
with patch(
|
||||
"backend.copilot.permissions.validate_block_identifiers"
|
||||
) as mock_validate:
|
||||
inp = _make_input(blocks=[])
|
||||
await _build_and_validate_permissions(inp)
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _merge_inherited_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergeInheritedPermissions:
|
||||
def test_no_permissions_no_parent_returns_none(self):
|
||||
merged, token = _merge_inherited_permissions(None)
|
||||
assert merged is None
|
||||
assert token is None
|
||||
|
||||
def test_permissions_no_parent_returned_unchanged(self):
|
||||
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
merged, token = _merge_inherited_permissions(perms)
|
||||
try:
|
||||
assert merged is perms
|
||||
assert token is not None
|
||||
finally:
|
||||
if token is not None:
|
||||
_inherited_permissions.reset(token)
|
||||
|
||||
def test_child_narrows_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
# Set parent as inherited
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
all_t = all_known_tool_names()
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert "bash_exec" not in effective
|
||||
assert "web_fetch" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_none_permissions_with_parent_uses_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
merged, inner_token = _merge_inherited_permissions(None)
|
||||
try:
|
||||
assert merged is not None
|
||||
# Merged should have parent's restrictions
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_child_cannot_expand_parent_whitelist(self):
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
# Child tries to allow more tools
|
||||
child = CopilotPermissions(
|
||||
tools=["run_block", "bash_exec"], tools_exclude=False
|
||||
)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
assert "run_block" in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AutoPilotBlock.run — validation integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAutoPilotBlockRunPermissions:
|
||||
async def _collect_outputs(self, block, input_data, user_id="test-user"):
|
||||
"""Helper to collect all yields from block.run()."""
|
||||
ctx = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id="g1",
|
||||
graph_exec_id="ge1",
|
||||
node_exec_id="ne1",
|
||||
node_id="n1",
|
||||
)
|
||||
outputs = {}
|
||||
async for key, val in block.run(input_data, execution_context=ctx):
|
||||
outputs[key] = val
|
||||
return outputs
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are caught at Pydantic validation (Literal type)."""
|
||||
with pytest.raises(ValidationError, match="not_a_tool"):
|
||||
_make_input(tools=["not_a_tool"])
|
||||
|
||||
async def test_invalid_block_yields_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(blocks=["nonexistent_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "nonexistent_block" in outputs["error"]
|
||||
|
||||
async def test_empty_prompt_yields_error_before_permission_check(self):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(prompt=" ", tools=["run_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "Prompt cannot be empty" in outputs["error"]
|
||||
|
||||
async def test_valid_permissions_passed_to_execute(self):
|
||||
"""Permissions are forwarded to execute_copilot when valid."""
|
||||
block = AutoPilotBlock()
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_execute_copilot(self_inner, **kwargs):
|
||||
captured["permissions"] = kwargs.get("permissions")
|
||||
return (
|
||||
"ok",
|
||||
[],
|
||||
'[{"role":"user","content":"hi"}]',
|
||||
"test-sid",
|
||||
{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
AutoPilotBlock, "create_session", new=AsyncMock(return_value="test-sid")
|
||||
), patch.object(AutoPilotBlock, "execute_copilot", new=fake_execute_copilot):
|
||||
inp = _make_input(tools=["run_block"], tools_exclude=False)
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
|
||||
assert "error" not in outputs
|
||||
perms = captured.get("permissions")
|
||||
assert isinstance(perms, CopilotPermissions)
|
||||
assert perms.tools == ["run_block"]
|
||||
assert perms.tools_exclude is False
|
||||
@@ -73,7 +73,7 @@ class ReadDiscordMessagesBlock(Block):
|
||||
id="df06086a-d5ac-4abb-9996-2ad0acb2eff7",
|
||||
input_schema=ReadDiscordMessagesBlock.Input, # Assign input schema
|
||||
output_schema=ReadDiscordMessagesBlock.Output, # Assign output schema
|
||||
description="Reads messages from a Discord channel using a bot token.",
|
||||
description="Reads new messages from a Discord channel using a bot token and triggers when a new message is posted",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
test_input={
|
||||
"continuous_read": False,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
from abc import ABC
|
||||
from email import encoders
|
||||
from email.mime.base import MIMEBase
|
||||
@@ -8,7 +9,7 @@ from email.mime.text import MIMEText
|
||||
from email.policy import SMTP
|
||||
from email.utils import getaddresses, parseaddr
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
@@ -42,8 +43,52 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
|
||||
|
||||
|
||||
def serialize_email_recipients(recipients: list[str]) -> str:
|
||||
"""Serialize recipients list to comma-separated string."""
|
||||
return ", ".join(recipients)
|
||||
"""Serialize recipients list to comma-separated string.
|
||||
|
||||
Strips leading/trailing whitespace from each address to keep MIME
|
||||
headers clean (mirrors the strip done in ``validate_email_recipients``).
|
||||
"""
|
||||
return ", ".join(addr.strip() for addr in recipients)
|
||||
|
||||
|
||||
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
|
||||
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
|
||||
|
||||
|
||||
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
|
||||
"""Validate that all recipients are plausible email addresses.
|
||||
|
||||
Raises ``ValueError`` with a user-friendly message listing every
|
||||
invalid entry so the caller (or LLM) can correct them in one pass.
|
||||
"""
|
||||
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
|
||||
if invalid:
|
||||
formatted = ", ".join(f"'{a}'" for a in invalid)
|
||||
raise ValueError(
|
||||
f"Invalid email address(es) in '{field_name}': {formatted}. "
|
||||
f"Each entry must be a valid email address (e.g. user@example.com)."
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasRecipients(Protocol):
|
||||
to: list[str]
|
||||
cc: list[str]
|
||||
bcc: list[str]
|
||||
|
||||
|
||||
def validate_all_recipients(input_data: HasRecipients) -> None:
|
||||
"""Validate to/cc/bcc recipient fields on an input namespace.
|
||||
|
||||
Calls ``validate_email_recipients`` for ``to`` (required) and
|
||||
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
|
||||
first field that contains an invalid address.
|
||||
"""
|
||||
validate_email_recipients(input_data.to, "to")
|
||||
if input_data.cc:
|
||||
validate_email_recipients(input_data.cc, "cc")
|
||||
if input_data.bcc:
|
||||
validate_email_recipients(input_data.bcc, "bcc")
|
||||
|
||||
|
||||
def _make_mime_text(
|
||||
@@ -100,14 +145,16 @@ async def create_mime_message(
|
||||
) -> str:
|
||||
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
||||
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
message = MIMEMultipart()
|
||||
message["to"] = serialize_email_recipients(input_data.to)
|
||||
message["subject"] = input_data.subject
|
||||
|
||||
if input_data.cc:
|
||||
message["cc"] = ", ".join(input_data.cc)
|
||||
message["cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
message["bcc"] = ", ".join(input_data.bcc)
|
||||
message["bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
|
||||
# Use the new helper function with content_type if available
|
||||
content_type = getattr(input_data, "content_type", None)
|
||||
@@ -1167,13 +1214,15 @@ async def _build_reply_message(
|
||||
references.append(headers["message-id"])
|
||||
|
||||
# Create MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
@@ -1685,13 +1734,16 @@ To: {original_to}
|
||||
else:
|
||||
body = f"{forward_header}\n\n{original_body}"
|
||||
|
||||
# Validate all recipient lists before building the MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
# Create MIME message
|
||||
msg = MIMEMultipart()
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
|
||||
# Add body with proper content type
|
||||
|
||||
@@ -28,9 +28,9 @@ class AgentInputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, name, description, default values list and bool to limit selection to default values.
|
||||
It takes in a value, name, and description.
|
||||
|
||||
It Outputs the value passed as input.
|
||||
It outputs the value passed as input.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -47,12 +47,6 @@ class AgentInputBlock(Block):
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="The placeholder values to be passed as input.",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
hidden=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
@@ -65,10 +59,7 @@ class AgentInputBlock(Block):
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = copy.deepcopy(self.get_field_schema("value"))
|
||||
if possible_values := self.placeholder_values:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
return copy.deepcopy(self.get_field_schema("value"))
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Use BlockSchema to avoid automatic error field for interface definition
|
||||
@@ -86,18 +77,16 @@ class AgentInputBlock(Block):
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "Example test input.",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"value": 42,
|
||||
"name": "input_2",
|
||||
"description": "Example test input with placeholders.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
"description": "Example numeric input.",
|
||||
},
|
||||
],
|
||||
"test_output": [
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
("result", 42),
|
||||
],
|
||||
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
"block_type": BlockType.INPUT,
|
||||
@@ -245,13 +234,11 @@ class AgentShortTextInputBlock(AgentInputBlock):
|
||||
"value": "Hello",
|
||||
"name": "short_text_1",
|
||||
"description": "Short text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Quick test",
|
||||
"name": "short_text_2",
|
||||
"description": "Short text example 2",
|
||||
"placeholder_values": ["Quick test", "Another option"],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -285,13 +272,11 @@ class AgentLongTextInputBlock(AgentInputBlock):
|
||||
"value": "Lorem ipsum dolor sit amet...",
|
||||
"name": "long_text_1",
|
||||
"description": "Long text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Another multiline text input.",
|
||||
"name": "long_text_2",
|
||||
"description": "Long text example 2",
|
||||
"placeholder_values": ["Another multiline text input."],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -325,13 +310,11 @@ class AgentNumberInputBlock(AgentInputBlock):
|
||||
"value": 42,
|
||||
"name": "number_input_1",
|
||||
"description": "Number example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": 314,
|
||||
"name": "number_input_2",
|
||||
"description": "Number example 2",
|
||||
"placeholder_values": [314, 2718],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -501,6 +484,12 @@ class AgentDropdownInputBlock(AgentInputBlock):
|
||||
title="Dropdown Options",
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = super().generate_schema()
|
||||
if possible_values := self.placeholder_values:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Selected dropdown value.")
|
||||
|
||||
|
||||
@@ -49,6 +49,9 @@ settings = Settings()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
|
||||
# HTTP status codes for user-caused errors that should not be reported to Sentry.
|
||||
USER_ERROR_STATUS_CODES = (401, 403, 429)
|
||||
|
||||
LLMProviderName = Literal[
|
||||
ProviderName.AIML_API,
|
||||
ProviderName.ANTHROPIC,
|
||||
@@ -101,6 +104,18 @@ class LlmModelMeta(EnumMeta):
|
||||
|
||||
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> "LlmModel | None":
|
||||
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
|
||||
if isinstance(value, str) and "/" in value:
|
||||
stripped = value.split("/", 1)[1]
|
||||
try:
|
||||
return cls(stripped)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
# OpenAI models
|
||||
O3_MINI = "o3-mini"
|
||||
O3 = "o3-2025-04-16"
|
||||
@@ -709,6 +724,9 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
def extract_openai_reasoning(response) -> str | None:
|
||||
"""Extract reasoning from OpenAI-compatible response if available."""
|
||||
"""Note: This will likely not working since the reasoning is not present in another Response API"""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_reasoning")
|
||||
return None
|
||||
reasoning = None
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
|
||||
@@ -724,6 +742,9 @@ def extract_openai_reasoning(response) -> str | None:
|
||||
|
||||
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
"""Extract tool calls from OpenAI-compatible response."""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
|
||||
return None
|
||||
if response.choices[0].message.tool_calls:
|
||||
return [
|
||||
ToolContentBlock(
|
||||
@@ -891,65 +912,60 @@ async def llm_call(
|
||||
client = anthropic.AsyncAnthropic(
|
||||
api_key=credentials.api_key.get_secret_value()
|
||||
)
|
||||
try:
|
||||
resp = await client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
resp = await client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
# Antropic is different to openai, need to iterate through
|
||||
# the content blocks to find the tool calls
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
# Antropic is different to openai, need to iterate through
|
||||
# the content blocks to find the tool calls
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else getattr(resp.content[0], "text", "")
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else getattr(resp.content[0], "text", "")
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "groq":
|
||||
if tools:
|
||||
raise ValueError("Groq does not support tools.")
|
||||
@@ -962,6 +978,8 @@ async def llm_call(
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not response.choices:
|
||||
raise ValueError("Groq returned empty choices in response")
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
@@ -1021,12 +1039,8 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"OpenRouter error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
raise ValueError(f"OpenRouter returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1063,12 +1077,8 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"Llama API error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from Llama API.")
|
||||
raise ValueError(f"Llama API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1098,6 +1108,8 @@ async def llm_call(
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not completion.choices:
|
||||
raise ValueError("AI/ML API returned empty choices in response")
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=completion.choices[0].message,
|
||||
@@ -1134,6 +1146,9 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError(f"v0 API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
@@ -1462,7 +1477,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
is_user_error = (
|
||||
isinstance(e, (anthropic.APIStatusError, openai.APIStatusError))
|
||||
and e.status_code in USER_ERROR_STATUS_CODES
|
||||
)
|
||||
if is_user_error:
|
||||
logger.warning(f"Error calling LLM: {e}")
|
||||
error_feedback_message = f"Error calling LLM: {e}"
|
||||
break
|
||||
else:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
if (
|
||||
"maximum context length" in str(e).lower()
|
||||
or "token limit" in str(e).lower()
|
||||
@@ -1992,6 +2016,19 @@ class AIConversationBlock(AIBlockBase):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
has_messages = any(
|
||||
isinstance(m, dict)
|
||||
and isinstance(m.get("content"), str)
|
||||
and bool(m["content"].strip())
|
||||
for m in (input_data.messages or [])
|
||||
)
|
||||
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
|
||||
if not has_messages and not has_prompt:
|
||||
raise ValueError(
|
||||
"Cannot call LLM with no messages and no prompt. "
|
||||
"Provide at least one message or a non-empty prompt."
|
||||
)
|
||||
|
||||
response = await self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt=input_data.prompt,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,13 +1,8 @@
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
# Monkey patch Stagehands to prevent signal handling in worker threads
|
||||
import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
from stagehand import AsyncStagehand
|
||||
from stagehand.types.session_act_params import Options as ActOptions
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
@@ -28,46 +23,6 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
|
||||
warnings.filterwarnings("ignore", module="litellm.llms.custom_httpx")
|
||||
|
||||
# Store the original method
|
||||
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
|
||||
|
||||
|
||||
def safe_register_signal_handlers(self):
|
||||
"""Only register signal handlers in the main thread"""
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
original_register_signal_handlers(self)
|
||||
else:
|
||||
# Skip signal handling in worker threads
|
||||
pass
|
||||
|
||||
|
||||
# Replace the method
|
||||
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_signal_handling():
|
||||
"""Context manager to temporarily disable signal handling"""
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
# In worker threads, temporarily replace signal.signal with a no-op
|
||||
original_signal = signal.signal
|
||||
|
||||
def noop_signal(*args, **kwargs):
|
||||
pass
|
||||
|
||||
signal.signal = noop_signal
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal = original_signal
|
||||
else:
|
||||
# In main thread, don't modify anything
|
||||
yield
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -148,13 +103,10 @@ class StagehandObserveBlock(Block):
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
@@ -185,32 +137,28 @@ class StagehandObserveBlock(Block):
|
||||
|
||||
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
|
||||
observe_results = await page.observe(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
for result in observe_results:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
observe_response = await session.observe(
|
||||
instruction=input_data.instruction,
|
||||
)
|
||||
for result in observe_response.data.result:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
|
||||
class StagehandActBlock(Block):
|
||||
@@ -242,24 +190,22 @@ class StagehandActBlock(Block):
|
||||
description="Variables to use in the action. Variables contains data you want the action to use.",
|
||||
default_factory=dict,
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
timeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
|
||||
default=60000,
|
||||
timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms for each action.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the action was completed successfully"
|
||||
)
|
||||
message: str = SchemaField(description="Details about the action’s execution.")
|
||||
message: str = SchemaField(description="Details about the action's execution.")
|
||||
action: str = SchemaField(description="Action performed")
|
||||
|
||||
def __init__(self):
|
||||
@@ -282,32 +228,33 @@ class StagehandActBlock(Block):
|
||||
|
||||
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
for action in input_data.action:
|
||||
action_results = await page.act(
|
||||
action,
|
||||
variables=input_data.variables,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
timeoutMs=input_data.timeoutMs,
|
||||
)
|
||||
yield "success", action_results.success
|
||||
yield "message", action_results.message
|
||||
yield "action", action_results.action
|
||||
for action in input_data.action:
|
||||
act_options = ActOptions(
|
||||
variables={k: v for k, v in input_data.variables.items()},
|
||||
timeout=input_data.timeout_ms,
|
||||
)
|
||||
act_response = await session.act(
|
||||
input=action,
|
||||
options=act_options,
|
||||
)
|
||||
result = act_response.data.result
|
||||
yield "success", result.success
|
||||
yield "message", result.message
|
||||
yield "action", result.action_description
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
|
||||
class StagehandExtractBlock(Block):
|
||||
@@ -335,13 +282,10 @@ class StagehandExtractBlock(Block):
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
@@ -367,24 +311,21 @@ class StagehandExtractBlock(Block):
|
||||
|
||||
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
extraction = await page.extract(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
yield "extraction", str(extraction.model_dump()["extraction"])
|
||||
extract_response = await session.extract(
|
||||
instruction=input_data.instruction,
|
||||
)
|
||||
yield "extraction", str(extract_response.data.result)
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
@@ -4,6 +4,8 @@ import pytest
|
||||
|
||||
from backend.blocks import get_blocks
|
||||
from backend.blocks._base import Block, BlockSchemaInput
|
||||
from backend.blocks.io import AgentDropdownInputBlock, AgentInputBlock
|
||||
from backend.data.graph import BaseGraph
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
@@ -279,3 +281,66 @@ class TestAutoCredentialsFieldsValidation:
|
||||
assert "Duplicate auto_credentials kwarg_name 'credentials'" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
|
||||
def test_agent_input_block_ignores_legacy_placeholder_values():
|
||||
"""Verify AgentInputBlock.Input.model_construct tolerates extra placeholder_values
|
||||
for backward compatibility with existing agent JSON."""
|
||||
legacy_data = {
|
||||
"name": "url",
|
||||
"value": "",
|
||||
"description": "Enter a URL",
|
||||
"placeholder_values": ["https://example.com"],
|
||||
}
|
||||
instance = AgentInputBlock.Input.model_construct(**legacy_data)
|
||||
schema = instance.generate_schema()
|
||||
assert (
|
||||
"enum" not in schema
|
||||
), "AgentInputBlock should not produce enum from legacy placeholder_values"
|
||||
|
||||
|
||||
def test_dropdown_input_block_produces_enum():
|
||||
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
|
||||
options = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_construct(
|
||||
name="choice", value=None, placeholder_values=options
|
||||
)
|
||||
schema = instance.generate_schema()
|
||||
assert schema.get("enum") == options
|
||||
|
||||
|
||||
def test_generate_schema_integration_legacy_placeholder_values():
|
||||
"""Test the full Graph._generate_schema path with legacy placeholder_values
|
||||
on AgentInputBlock — verifies no enum leaks through the graph loading path."""
|
||||
legacy_input_default = {
|
||||
"name": "url",
|
||||
"value": "",
|
||||
"description": "Enter a URL",
|
||||
"placeholder_values": ["https://example.com"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentInputBlock.Input, legacy_input_default),
|
||||
)
|
||||
url_props = result["properties"]["url"]
|
||||
assert (
|
||||
"enum" not in url_props
|
||||
), "Graph schema should not contain enum from AgentInputBlock placeholder_values"
|
||||
|
||||
|
||||
def test_generate_schema_integration_dropdown_produces_enum():
|
||||
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
|
||||
— verifies enum IS produced for dropdown blocks."""
|
||||
dropdown_input_default = {
|
||||
"name": "color",
|
||||
"value": None,
|
||||
"placeholder_values": ["Red", "Green", "Blue"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentDropdownInputBlock.Input, dropdown_input_default),
|
||||
)
|
||||
color_props = result["properties"]["color"]
|
||||
assert color_props.get("enum") == [
|
||||
"Red",
|
||||
"Green",
|
||||
"Blue",
|
||||
], "Graph schema should contain enum from AgentDropdownInputBlock"
|
||||
|
||||
@@ -207,6 +207,51 @@ class TestXMLParserBlockSecurity:
|
||||
pass
|
||||
|
||||
|
||||
class TestXMLParserBlockSyntaxErrors:
|
||||
"""XML syntax errors should raise ValueError (not SyntaxError).
|
||||
|
||||
This ensures the base Block.execute() wraps them as BlockExecutionError
|
||||
(expected / user-caused) instead of BlockUnknownError (unexpected / alerts
|
||||
Sentry).
|
||||
"""
|
||||
|
||||
async def test_unclosed_tag_raises_value_error(self):
|
||||
"""Unclosed tags should raise ValueError, not SyntaxError."""
|
||||
block = XMLParserBlock()
|
||||
bad_xml = "<root><unclosed>"
|
||||
|
||||
with pytest.raises(ValueError, match="Unclosed tag"):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
|
||||
pass
|
||||
|
||||
async def test_unexpected_closing_tag_raises_value_error(self):
|
||||
"""Extra closing tags should raise ValueError, not SyntaxError."""
|
||||
block = XMLParserBlock()
|
||||
bad_xml = "</unexpected>"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
|
||||
pass
|
||||
|
||||
async def test_empty_xml_raises_value_error(self):
|
||||
"""Empty XML input should raise ValueError."""
|
||||
block = XMLParserBlock()
|
||||
|
||||
with pytest.raises(ValueError, match="XML input is empty"):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml="")):
|
||||
pass
|
||||
|
||||
async def test_syntax_error_from_parser_becomes_value_error(self):
|
||||
"""SyntaxErrors from gravitasml library become ValueError (BlockExecutionError)."""
|
||||
block = XMLParserBlock()
|
||||
# Malformed XML that might trigger a SyntaxError from the parser
|
||||
bad_xml = "<root><child>no closing"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
|
||||
pass
|
||||
|
||||
|
||||
class TestStoreMediaFileSecurity:
|
||||
"""Test file storage security limits."""
|
||||
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import anthropic
|
||||
import httpx
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
# TEST_CREDENTIALS_INPUT is a plain dict that satisfies AICredentials at runtime
|
||||
# but not at the type level. Cast once here to avoid per-test suppressors.
|
||||
_TEST_AI_CREDENTIALS = cast(llm.AICredentials, llm.TEST_CREDENTIALS_INPUT)
|
||||
|
||||
|
||||
class TestLLMStatsTracking:
|
||||
"""Test that LLM blocks correctly track token usage statistics."""
|
||||
@@ -479,6 +488,154 @@ class TestLLMStatsTracking:
|
||||
assert outputs["response"] == {"result": "test"}
|
||||
|
||||
|
||||
class TestAIConversationBlockValidation:
|
||||
"""Test that AIConversationBlock validates inputs before calling the LLM."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_messages_and_empty_prompt_raises_error(self):
|
||||
"""Empty messages with no prompt should raise ValueError, not a cryptic API error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_messages_with_prompt_succeeds(self):
|
||||
"""Empty messages but a non-empty prompt should proceed without error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
async def mock_llm_call(input_data, credentials):
|
||||
return {"response": "OK"}
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[],
|
||||
prompt="Hello, how are you?",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[name] = data
|
||||
|
||||
assert outputs["response"] == "OK"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonempty_messages_with_empty_prompt_succeeds(self):
|
||||
"""Non-empty messages with no prompt should proceed without error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
async def mock_llm_call(input_data, credentials):
|
||||
return {"response": "response from conversation"}
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[name] = data
|
||||
|
||||
assert outputs["response"] == "response from conversation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_empty_content_raises_error(self):
|
||||
"""Messages with empty content strings should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": ""}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_whitespace_content_raises_error(self):
|
||||
"""Messages with whitespace-only content should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": " "}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_none_entry_raises_error(self):
|
||||
"""Messages list containing None should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[None],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_empty_dict_raises_error(self):
|
||||
"""Messages list containing empty dict should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_none_content_raises_error(self):
|
||||
"""Messages with content=None should not crash with AttributeError."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": None}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
class TestAITextSummarizerValidation:
|
||||
"""Test that AITextSummarizerBlock validates LLM responses are strings."""
|
||||
|
||||
@@ -655,3 +812,178 @@ class TestAITextSummarizerValidation:
|
||||
error_message = str(exc_info.value)
|
||||
assert "Expected a string summary" in error_message
|
||||
assert "received dict" in error_message
|
||||
|
||||
|
||||
def _make_anthropic_status_error(status_code: int) -> anthropic.APIStatusError:
|
||||
"""Create an anthropic.APIStatusError with the given status code."""
|
||||
request = httpx.Request("POST", "https://api.anthropic.com/v1/messages")
|
||||
response = httpx.Response(status_code, request=request)
|
||||
return anthropic.APIStatusError(
|
||||
f"Error code: {status_code}", response=response, body=None
|
||||
)
|
||||
|
||||
|
||||
def _make_openai_status_error(status_code: int) -> openai.APIStatusError:
|
||||
"""Create an openai.APIStatusError with the given status code."""
|
||||
response = httpx.Response(
|
||||
status_code, request=httpx.Request("POST", "https://api.openai.com/v1/chat")
|
||||
)
|
||||
return openai.APIStatusError(
|
||||
f"Error code: {status_code}", response=response, body=None
|
||||
)
|
||||
|
||||
|
||||
class TestUserErrorStatusCodeHandling:
|
||||
"""Test that user-caused LLM API errors (401/403/429) break the retry loop
|
||||
and are logged as warnings, while server errors (500) trigger retries."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("status_code", [401, 403, 429])
|
||||
async def test_anthropic_user_error_breaks_retry_loop(self, status_code: int):
|
||||
"""401/403/429 Anthropic errors should break immediately, not retry."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_anthropic_status_error(status_code)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count == 1
|
||||
), f"Expected exactly 1 call for status {status_code}, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("status_code", [401, 403, 429])
|
||||
async def test_openai_user_error_breaks_retry_loop(self, status_code: int):
|
||||
"""401/403/429 OpenAI errors should break immediately, not retry."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_openai_status_error(status_code)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count == 1
|
||||
), f"Expected exactly 1 call for status {status_code}, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_error_retries(self):
|
||||
"""500 errors should be retried (not break immediately)."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_anthropic_status_error(500)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count > 1
|
||||
), f"Expected multiple retry attempts for 500, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_error_logs_warning_not_exception(self):
|
||||
"""User-caused errors should log with logger.warning, not logger.exception."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
raise _make_anthropic_status_error(401)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(llm.logger, "warning") as mock_warning,
|
||||
patch.object(llm.logger, "exception") as mock_exception,
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
mock_warning.assert_called_once()
|
||||
mock_exception.assert_not_called()
|
||||
|
||||
|
||||
class TestLlmModelMissing:
|
||||
"""Test that LlmModel handles provider-prefixed model names."""
|
||||
|
||||
def test_provider_prefixed_model_resolves(self):
|
||||
"""Provider-prefixed model string should resolve to the correct enum member."""
|
||||
assert (
|
||||
llm.LlmModel("anthropic/claude-sonnet-4-6")
|
||||
== llm.LlmModel.CLAUDE_4_6_SONNET
|
||||
)
|
||||
|
||||
def test_bare_model_still_works(self):
|
||||
"""Bare (non-prefixed) model string should still resolve correctly."""
|
||||
assert llm.LlmModel("claude-sonnet-4-6") == llm.LlmModel.CLAUDE_4_6_SONNET
|
||||
|
||||
def test_invalid_prefixed_model_raises(self):
|
||||
"""Unknown provider-prefixed model string should raise ValueError."""
|
||||
with pytest.raises(ValueError):
|
||||
llm.LlmModel("invalid/nonexistent-model")
|
||||
|
||||
def test_slash_containing_value_direct_lookup(self):
|
||||
"""Enum values with '/' (e.g., OpenRouter models) should resolve via direct lookup, not _missing_."""
|
||||
assert llm.LlmModel("google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
|
||||
|
||||
def test_double_prefixed_slash_model(self):
|
||||
"""Double-prefixed value should still resolve by stripping first prefix."""
|
||||
assert (
|
||||
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
|
||||
)
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
|
||||
|
||||
|
||||
class TestExtractOpenaiToolCallsEmptyChoices:
|
||||
"""extract_openai_tool_calls() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_tool_calls_when_choices_present(self):
|
||||
tool = MagicMock()
|
||||
tool.id = "call_1"
|
||||
tool.type = "function"
|
||||
tool.function.name = "my_func"
|
||||
tool.function.arguments = '{"a": 1}'
|
||||
|
||||
message = MagicMock()
|
||||
message.tool_calls = [tool]
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_tool_calls(response)
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].function.name == "my_func"
|
||||
|
||||
def test_returns_none_when_no_tool_calls(self):
|
||||
message = MagicMock()
|
||||
message.tool_calls = None
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
|
||||
class TestExtractOpenaiReasoningEmptyChoices:
|
||||
"""extract_openai_reasoning() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_reasoning_from_choice(self):
|
||||
choice = MagicMock()
|
||||
choice.reasoning = "Step-by-step reasoning"
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr on response
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result == "Step-by-step reasoning"
|
||||
|
||||
def test_returns_none_when_no_reasoning(self):
|
||||
choice = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result is None
|
||||
@@ -57,7 +57,7 @@ async def execute_graph(
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data import graph
|
||||
|
||||
test_user = await create_test_user()
|
||||
@@ -66,7 +66,7 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
block_id=OrchestratorBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": creds,
|
||||
@@ -108,10 +108,10 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
async def test_orchestrator_function_signature(server: SpinTestServer):
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data import graph
|
||||
|
||||
test_user = await create_test_user()
|
||||
@@ -120,7 +120,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
block_id=OrchestratorBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": creds,
|
||||
@@ -169,7 +169,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
)
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
tool_functions = await SmartDecisionMakerBlock._create_tool_node_signatures(
|
||||
tool_functions = await OrchestratorBlock._create_tool_node_signatures(
|
||||
test_graph.nodes[0].id
|
||||
)
|
||||
assert tool_functions is not None, "Tool functions should not be None"
|
||||
@@ -198,12 +198,12 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_tracks_llm_stats():
|
||||
"""Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
|
||||
async def test_orchestrator_tracks_llm_stats():
|
||||
"""Test that OrchestratorBlock correctly tracks LLM usage stats."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock the llm.llm_call function to return controlled data
|
||||
mock_response = MagicMock()
|
||||
@@ -224,14 +224,14 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
|
||||
# Create test input
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Should I continue with this task?",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -274,12 +274,12 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_parameter_validation():
|
||||
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
|
||||
async def test_orchestrator_parameter_validation():
|
||||
"""Test that OrchestratorBlock correctly validates tool call parameters."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool functions with specific parameter schema
|
||||
mock_tool_functions = [
|
||||
@@ -327,13 +327,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_with_typo,
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -394,13 +394,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_missing_required,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -454,13 +454,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_valid,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -518,13 +518,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_all_params,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -562,12 +562,12 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_raw_response_conversion():
|
||||
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
|
||||
async def test_orchestrator_raw_response_conversion():
|
||||
"""Test that Orchestrator correctly handles different raw_response types with retry mechanism."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool functions
|
||||
mock_tool_functions = [
|
||||
@@ -637,7 +637,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
@@ -646,7 +646,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
# Second call returns successful response
|
||||
mock_llm_call.side_effect = [mock_response_retry, mock_response_success]
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -715,12 +715,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_ollama,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[], # No tools for this test
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -771,12 +771,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_dict,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -811,12 +811,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_mode():
|
||||
async def test_orchestrator_agent_mode():
|
||||
"""Test that agent mode executes tools directly and loops until finished."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool call that requires multiple iterations
|
||||
mock_tool_call_1 = MagicMock()
|
||||
@@ -893,7 +893,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
|
||||
), patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
), patch(
|
||||
"backend.executor.manager.async_update_node_execution_status",
|
||||
@@ -929,7 +929,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
}
|
||||
|
||||
# Test agent mode with max_iterations = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Complete this task using tools",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -969,12 +969,12 @@ async def test_smart_decision_maker_agent_mode():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_traditional_mode_default():
|
||||
async def test_orchestrator_traditional_mode_default():
|
||||
"""Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool call
|
||||
mock_tool_call = MagicMock()
|
||||
@@ -1018,7 +1018,7 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
):
|
||||
|
||||
# Test default behavior (traditional mode)
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -1060,12 +1060,12 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
|
||||
async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
"""Test that OrchestratorBlock uses customized_name from node metadata for tool names."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
@@ -1074,13 +1074,14 @@ async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
result = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1091,12 +1092,12 @@ async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
|
||||
async def test_orchestrator_falls_back_to_block_name():
|
||||
"""Test that OrchestratorBlock falls back to block.name when no customized_name."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
@@ -1105,13 +1106,14 @@ async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
result = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1122,11 +1124,11 @@ async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
|
||||
async def test_orchestrator_uses_customized_name_for_agents():
|
||||
"""Test that OrchestratorBlock uses customized_name from metadata for agent nodes."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
@@ -1152,10 +1154,10 @@ async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
result = await OrchestratorBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1166,11 +1168,11 @@ async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
async def test_orchestrator_agent_falls_back_to_graph_name():
|
||||
"""Test that agent node falls back to graph name when no customized_name."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
@@ -1196,10 +1198,10 @@ async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
result = await OrchestratorBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -3,12 +3,12 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
|
||||
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
"""Test Smart Decision Maker can handle dynamic dictionary fields (_#_) for any block"""
|
||||
async def test_orchestrator_handles_dynamic_dict_fields():
|
||||
"""Test Orchestrator can handle dynamic dictionary fields (_#_) for any block"""
|
||||
|
||||
# Create a mock node for CreateDictionaryBlock
|
||||
mock_node = Mock()
|
||||
@@ -23,24 +23,24 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
source_name="tools_^_create_dict_~_name",
|
||||
sink_name="values_#_name", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_age",
|
||||
sink_name="values_#_age", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_city",
|
||||
sink_name="values_#_city", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
signature = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, mock_links # type: ignore
|
||||
)
|
||||
|
||||
@@ -70,8 +70,8 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
"""Test Smart Decision Maker can handle dynamic list fields (_$_) for any block"""
|
||||
async def test_orchestrator_handles_dynamic_list_fields():
|
||||
"""Test Orchestrator can handle dynamic list fields (_$_) for any block"""
|
||||
|
||||
# Create a mock node for AddToListBlock
|
||||
mock_node = Mock()
|
||||
@@ -86,18 +86,18 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
source_name="tools_^_add_to_list_~_0",
|
||||
sink_name="entries_$_0", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_to_list_~_1",
|
||||
sink_name="entries_$_1", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
signature = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, mock_links # type: ignore
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
|
||||
"""Comprehensive tests for OrchestratorBlock dynamic field handling."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
import pytest
|
||||
|
||||
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.blocks.text import MatchTextPatternBlock
|
||||
from backend.data.dynamic_fields import get_dynamic_field_description
|
||||
|
||||
@@ -37,7 +37,7 @@ async def test_dynamic_field_description_generation():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_dict_fields():
|
||||
"""Test that function signatures are created correctly for dictionary dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for CreateDictionaryBlock
|
||||
mock_node = Mock()
|
||||
@@ -52,19 +52,19 @@ async def test_create_block_function_signature_with_dict_fields():
|
||||
source_name="tools_^_create_dict_~_values___name", # Sanitized source
|
||||
sink_name="values_#_name", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___age", # Sanitized source
|
||||
sink_name="values_#_age", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___email", # Sanitized source
|
||||
sink_name="values_#_email", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -100,7 +100,7 @@ async def test_create_block_function_signature_with_dict_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_list_fields():
|
||||
"""Test that function signatures are created correctly for list dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for AddToListBlock
|
||||
mock_node = Mock()
|
||||
@@ -115,19 +115,19 @@ async def test_create_block_function_signature_with_list_fields():
|
||||
source_name="tools_^_add_list_~_0",
|
||||
sink_name="entries_$_0", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_1",
|
||||
sink_name="entries_$_1", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_2",
|
||||
sink_name="entries_$_2", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -154,7 +154,7 @@ async def test_create_block_function_signature_with_list_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_object_fields():
|
||||
"""Test that function signatures are created correctly for object dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for MatchTextPatternBlock (simulating object fields)
|
||||
mock_node = Mock()
|
||||
@@ -169,13 +169,13 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
source_name="tools_^_extract_~_user_name",
|
||||
sink_name="data_@_user_name", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_extract_~_user_email",
|
||||
sink_name="data_@_user_email", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -197,11 +197,11 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_tool_node_signatures():
|
||||
"""Test that the mapping between sanitized and original field names is built correctly."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock the database client and connected nodes
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db:
|
||||
mock_client = AsyncMock()
|
||||
mock_db.return_value = mock_client
|
||||
@@ -281,7 +281,7 @@ async def test_create_tool_node_signatures():
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_yielding_with_dynamic_fields():
|
||||
"""Test that outputs are yielded correctly with dynamic field names mapped back."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# No more sanitized mapping needed since we removed sanitization
|
||||
|
||||
@@ -309,13 +309,13 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
# Mock the database manager to avoid HTTP calls during tool execution
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db_manager, patch.object(
|
||||
block, "_create_tool_node_signatures", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
@@ -420,7 +420,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_regular_and_dynamic_fields():
|
||||
"""Test handling of blocks with both regular and dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node
|
||||
mock_node = Mock()
|
||||
@@ -450,19 +450,19 @@ async def test_mixed_regular_and_dynamic_fields():
|
||||
source_name="tools_^_test_~_regular",
|
||||
sink_name="regular_field", # Regular field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key",
|
||||
sink_name="values_#_key1", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key2",
|
||||
sink_name="values_#_key2", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -488,7 +488,7 @@ async def test_mixed_regular_and_dynamic_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_errors_dont_pollute_conversation():
|
||||
"""Test that validation errors are only used during retries and don't pollute the conversation."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Track conversation history changes
|
||||
conversation_snapshots = []
|
||||
@@ -535,7 +535,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.side_effect = mock_llm_call
|
||||
|
||||
@@ -565,7 +565,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
|
||||
# Mock the database manager to avoid HTTP calls during tool execution
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db_manager:
|
||||
# Set up the mock database manager for agent mode
|
||||
mock_db_client = AsyncMock()
|
||||
@@ -0,0 +1,202 @@
|
||||
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
|
||||
|
||||
Covers:
|
||||
- ExecutionMode enum members exist and have stable values
|
||||
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
|
||||
- EXTENDED_THINKING model-name validation (must start with "claude")
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExecutionMode enum integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecutionModeEnum:
|
||||
"""Guard against accidental renames or removals of enum members."""
|
||||
|
||||
def test_built_in_exists(self):
|
||||
assert hasattr(ExecutionMode, "BUILT_IN")
|
||||
assert ExecutionMode.BUILT_IN.value == "built_in"
|
||||
|
||||
def test_extended_thinking_exists(self):
|
||||
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
|
||||
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
|
||||
|
||||
def test_exactly_two_members(self):
|
||||
"""If a new mode is added, this test should be updated intentionally."""
|
||||
assert set(ExecutionMode.__members__.keys()) == {
|
||||
"BUILT_IN",
|
||||
"EXTENDED_THINKING",
|
||||
}
|
||||
|
||||
def test_string_enum(self):
|
||||
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
|
||||
assert isinstance(ExecutionMode.BUILT_IN, str)
|
||||
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
|
||||
|
||||
def test_round_trip_from_value(self):
|
||||
"""Constructing from the string value should return the same member."""
|
||||
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
|
||||
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider validation (inline in OrchestratorBlock.run)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_model_stub(provider: str, value: str):
|
||||
"""Create a lightweight stub that behaves like LlmModel for validation."""
|
||||
metadata = MagicMock()
|
||||
metadata.provider = provider
|
||||
stub = MagicMock()
|
||||
stub.metadata = metadata
|
||||
stub.value = value
|
||||
return stub
|
||||
|
||||
|
||||
class TestExtendedThinkingProviderValidation:
|
||||
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
|
||||
|
||||
def test_anthropic_provider_accepted(self):
|
||||
"""provider='anthropic' + claude model should not raise."""
|
||||
model = _make_model_stub("anthropic", "claude-opus-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_open_router_provider_accepted(self):
|
||||
"""provider='open_router' + claude model should not raise."""
|
||||
model = _make_model_stub("open_router", "claude-sonnet-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_openai_provider_rejected(self):
|
||||
"""provider='openai' should be rejected for EXTENDED_THINKING."""
|
||||
model = _make_model_stub("openai", "gpt-4o")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_groq_provider_rejected(self):
|
||||
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
|
||||
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
|
||||
model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
model_name = model.value
|
||||
assert not model_name.startswith("claude")
|
||||
|
||||
def test_real_gpt4o_model_rejected(self):
|
||||
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
|
||||
model = LlmModel.GPT4O
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_real_claude_model_passes(self):
|
||||
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
|
||||
model = LlmModel.CLAUDE_4_6_SONNET
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration-style: exercise the validation branch via OrchestratorBlock.run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
|
||||
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
|
||||
inp = MagicMock()
|
||||
inp.execution_mode = execution_mode
|
||||
inp.model = model
|
||||
inp.prompt = "test"
|
||||
inp.sys_prompt = ""
|
||||
inp.conversation_history = []
|
||||
inp.last_tool_output = None
|
||||
inp.prompt_values = {}
|
||||
return inp
|
||||
|
||||
|
||||
async def _collect_run_outputs(block, input_data, **kwargs):
|
||||
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
|
||||
outputs = []
|
||||
async for item in block.run(input_data, **kwargs):
|
||||
outputs.append(item)
|
||||
return outputs
|
||||
|
||||
|
||||
class TestExtendedThinkingValidationRaisesInBlock:
|
||||
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_anthropic_provider_raises_valueerror(self):
|
||||
"""EXTENDED_THINKING + openai provider raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
input_data = _make_input_data(model=LlmModel.GPT4O)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="Anthropic-compatible"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_claude_model_with_anthropic_provider_raises(self):
|
||||
"""A model with anthropic provider but non-claude name raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
input_data = _make_input_data(model=fake_model)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="only supports Claude models"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for SmartDecisionMakerBlock compatibility with the OpenAI Responses API.
|
||||
"""Tests for OrchestratorBlock compatibility with the OpenAI Responses API.
|
||||
|
||||
The SmartDecisionMakerBlock manages conversation history in the Chat Completions
|
||||
The OrchestratorBlock manages conversation history in the Chat Completions
|
||||
format, but OpenAI models now use the Responses API which has a fundamentally
|
||||
different conversation structure. These tests document:
|
||||
|
||||
@@ -27,8 +27,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.smart_decision_maker import (
|
||||
SmartDecisionMakerBlock,
|
||||
from backend.blocks.orchestrator import (
|
||||
OrchestratorBlock,
|
||||
_combine_tool_responses,
|
||||
_convert_raw_response_to_dict,
|
||||
_create_tool_response,
|
||||
@@ -733,7 +733,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_dict_raw_response_no_reasoning_no_tools(self):
|
||||
"""Dict raw_response, no reasoning → appends assistant dict."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response({"role": "assistant", "content": "hi"})
|
||||
block._update_conversation(prompt, resp)
|
||||
@@ -741,7 +741,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_dict_raw_response_with_reasoning_no_tool_calls(self):
|
||||
"""Reasoning present, no tool calls → reasoning prepended."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response(
|
||||
{"role": "assistant", "content": "answer"},
|
||||
@@ -757,7 +757,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_dict_raw_response_with_reasoning_and_anthropic_tool_calls(self):
|
||||
"""Reasoning + Anthropic tool_use in content → reasoning skipped."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
raw = {
|
||||
"role": "assistant",
|
||||
@@ -772,7 +772,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_with_tool_outputs(self):
|
||||
"""Tool outputs → extended onto prompt."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response({"role": "assistant", "content": None})
|
||||
outputs = [{"role": "tool", "tool_call_id": "call_1", "content": "r"}]
|
||||
@@ -782,7 +782,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_without_tool_outputs(self):
|
||||
"""No tool outputs → only assistant message appended."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response({"role": "assistant", "content": "done"})
|
||||
block._update_conversation(prompt, resp, None)
|
||||
@@ -790,7 +790,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_string_raw_response(self):
|
||||
"""Ollama string → wrapped as assistant dict."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response("hello from ollama")
|
||||
block._update_conversation(prompt, resp)
|
||||
@@ -800,7 +800,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_responses_api_text_response_produces_valid_items(self):
|
||||
"""Responses API text response → conversation items must have valid role."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "user"},
|
||||
@@ -820,7 +820,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_responses_api_function_call_produces_valid_items(self):
|
||||
"""Responses API function_call → conversation items must have valid type."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response(
|
||||
_MockResponse(output=[_MockFunctionCall("tool", "{}", call_id="call_1")])
|
||||
@@ -856,7 +856,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# First response: tool call
|
||||
mock_tc = MagicMock()
|
||||
@@ -936,7 +936,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
|
||||
with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=tool_sigs
|
||||
), patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db,
|
||||
), patch(
|
||||
"backend.executor.manager.async_update_node_execution_status",
|
||||
@@ -945,7 +945,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
|
||||
"backend.integrations.creds_manager.IntegrationCredentialsManager"
|
||||
):
|
||||
|
||||
inp = SmartDecisionMakerBlock.Input(
|
||||
inp = OrchestratorBlock.Input(
|
||||
prompt="Improve this",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -992,7 +992,7 @@ async def test_traditional_mode_conversation_valid_for_responses_api():
|
||||
"""Traditional mode: the yielded conversation must contain only valid items."""
|
||||
import backend.blocks.llm as llm_module
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
mock_tc = MagicMock()
|
||||
mock_tc.function.name = "my_tool"
|
||||
@@ -1028,7 +1028,7 @@ async def test_traditional_mode_conversation_valid_for_responses_api():
|
||||
"backend.blocks.llm.llm_call", new_callable=AsyncMock, return_value=resp
|
||||
), patch.object(block, "_create_tool_node_signatures", return_value=tool_sigs):
|
||||
|
||||
inp = SmartDecisionMakerBlock.Input(
|
||||
inp = OrchestratorBlock.Input(
|
||||
prompt="Do it",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
File diff suppressed because it is too large
Load Diff
@@ -44,7 +44,7 @@ class XMLParserBlock(Block):
|
||||
elif token.type == "TAG_CLOSE":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise SyntaxError("Unexpected closing tag in XML input.")
|
||||
raise ValueError("Unexpected closing tag in XML input.")
|
||||
elif token.type in {"TEXT", "ESCAPE"}:
|
||||
if depth == 0 and token.value:
|
||||
raise ValueError(
|
||||
@@ -53,7 +53,7 @@ class XMLParserBlock(Block):
|
||||
)
|
||||
|
||||
if depth != 0:
|
||||
raise SyntaxError("Unclosed tag detected in XML input.")
|
||||
raise ValueError("Unclosed tag detected in XML input.")
|
||||
if not root_seen:
|
||||
raise ValueError("XML must include a root element.")
|
||||
|
||||
@@ -76,4 +76,7 @@ class XMLParserBlock(Block):
|
||||
except ValueError as val_e:
|
||||
raise ValueError(f"Validation error for dict:{val_e}") from val_e
|
||||
except SyntaxError as syn_e:
|
||||
raise SyntaxError(f"Error in input xml syntax: {syn_e}") from syn_e
|
||||
# Raise as ValueError so the base Block.execute() wraps it as
|
||||
# BlockExecutionError (expected user-caused failure) instead of
|
||||
# BlockUnknownError (unexpected platform error that alerts Sentry).
|
||||
raise ValueError(f"Error in input xml syntax: {syn_e}") from syn_e
|
||||
|
||||
@@ -9,11 +9,14 @@ shared tool registry as the SDK path.
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, cast
|
||||
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
@@ -48,7 +51,17 @@ from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import compress_context
|
||||
from backend.util.prompt import (
|
||||
compress_context,
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
from backend.util.tool_call_loop import (
|
||||
LLMLoopResponse,
|
||||
LLMToolCall,
|
||||
ToolCallResult,
|
||||
tool_call_loop,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,6 +72,247 @@ _background_tasks: set[asyncio.Task[Any]] = set()
|
||||
_MAX_TOOL_ROUNDS = 30
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BaselineStreamState:
|
||||
"""Mutable state shared between the tool-call loop callbacks.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` so that the callbacks
|
||||
can be module-level functions instead of deeply nested closures.
|
||||
"""
|
||||
|
||||
pending_events: list[StreamBaseResponse] = field(default_factory=list)
|
||||
assistant_text: str = ""
|
||||
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
text_started: bool = False
|
||||
turn_prompt_tokens: int = 0
|
||||
turn_completion_tokens: int = 0
|
||||
|
||||
|
||||
async def _baseline_llm_caller(
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Sequence[Any],
|
||||
*,
|
||||
state: _BaselineStreamState,
|
||||
) -> LLMLoopResponse:
|
||||
"""Stream an OpenAI-compatible response and collect results.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
"""
|
||||
state.pending_events.append(StreamStartStep())
|
||||
|
||||
round_text = ""
|
||||
try:
|
||||
client = _get_openai_client()
|
||||
typed_messages = cast(list[ChatCompletionMessageParam], messages)
|
||||
if tools:
|
||||
typed_tools = cast(list[ChatCompletionToolParam], tools)
|
||||
response = await client.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=typed_messages,
|
||||
tools=typed_tools,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
else:
|
||||
response = await client.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=typed_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
if chunk.usage:
|
||||
state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
state.turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
if delta.content:
|
||||
if not state.text_started:
|
||||
state.pending_events.append(StreamTextStart(id=state.text_block_id))
|
||||
state.text_started = True
|
||||
round_text += delta.content
|
||||
state.pending_events.append(
|
||||
StreamTextDelta(id=state.text_block_id, delta=delta.content)
|
||||
)
|
||||
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_calls_by_index:
|
||||
tool_calls_by_index[idx] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
entry = tool_calls_by_index[idx]
|
||||
if tc.id:
|
||||
entry["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
entry["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
entry["arguments"] += tc.function.arguments
|
||||
|
||||
# Close text block
|
||||
if state.text_started:
|
||||
state.pending_events.append(StreamTextEnd(id=state.text_block_id))
|
||||
state.text_started = False
|
||||
state.text_block_id = str(uuid.uuid4())
|
||||
finally:
|
||||
# Always persist partial text so the session history stays consistent,
|
||||
# even when the stream is interrupted by an exception.
|
||||
state.assistant_text += round_text
|
||||
# Always emit StreamFinishStep to match the StreamStartStep,
|
||||
# even if an exception occurred during streaming.
|
||||
state.pending_events.append(StreamFinishStep())
|
||||
|
||||
# Convert to shared format
|
||||
llm_tool_calls = [
|
||||
LLMToolCall(
|
||||
id=tc["id"],
|
||||
name=tc["name"],
|
||||
arguments=tc["arguments"] or "{}",
|
||||
)
|
||||
for tc in tool_calls_by_index.values()
|
||||
]
|
||||
|
||||
return LLMLoopResponse(
|
||||
response_text=round_text or None,
|
||||
tool_calls=llm_tool_calls,
|
||||
raw_response=None, # Not needed for baseline conversation updater
|
||||
prompt_tokens=0, # Tracked via state accumulators
|
||||
completion_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
async def _baseline_tool_executor(
|
||||
tool_call: LLMToolCall,
|
||||
tools: Sequence[Any],
|
||||
*,
|
||||
state: _BaselineStreamState,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> ToolCallResult:
|
||||
"""Execute a tool via the copilot tool registry.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
"""
|
||||
tool_call_id = tool_call.id
|
||||
tool_name = tool_call.name
|
||||
raw_args = tool_call.arguments or "{}"
|
||||
|
||||
try:
|
||||
tool_args = orjson.loads(raw_args)
|
||||
except orjson.JSONDecodeError as parse_err:
|
||||
parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
|
||||
logger.warning("[Baseline] %s", parse_error)
|
||||
state.pending_events.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=parse_error,
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
content=parse_error,
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
state.pending_events.append(
|
||||
StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
|
||||
)
|
||||
state.pending_events.append(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=tool_args,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
state.pending_events.append(result)
|
||||
tool_output = (
|
||||
result.output if isinstance(result.output, str) else str(result.output)
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
content=tool_output,
|
||||
)
|
||||
except Exception as e:
|
||||
error_output = f"Tool execution error: {e}"
|
||||
logger.error(
|
||||
"[Baseline] Tool %s failed: %s",
|
||||
tool_name,
|
||||
error_output,
|
||||
exc_info=True,
|
||||
)
|
||||
state.pending_events.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=error_output,
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
content=error_output,
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
|
||||
def _baseline_conversation_updater(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
"""Update OpenAI message list with assistant response + tool results.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
"""
|
||||
if tool_results:
|
||||
# Build assistant message with tool_calls
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if response.response_text:
|
||||
assistant_msg["content"] = response.response_text
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {"name": tc.name, "arguments": tc.arguments},
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages.append(assistant_msg)
|
||||
for tr in tool_results:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.tool_call_id,
|
||||
"content": tr.content,
|
||||
}
|
||||
)
|
||||
else:
|
||||
if response.response_text:
|
||||
messages.append({"role": "assistant", "content": response.response_text})
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None
|
||||
) -> None:
|
||||
@@ -219,191 +473,32 @@ async def stream_chat_completion_baseline(
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context setup failed")
|
||||
|
||||
assistant_text = ""
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
# Token usage accumulators — populated from streaming chunks
|
||||
turn_prompt_tokens = 0
|
||||
turn_completion_tokens = 0
|
||||
_stream_error = False # Track whether an error occurred during streaming
|
||||
state = _BaselineStreamState()
|
||||
|
||||
# Bind extracted module-level callbacks to this request's state/session
|
||||
# using functools.partial so they satisfy the Protocol signatures.
|
||||
_bound_llm_caller = partial(_baseline_llm_caller, state=state)
|
||||
_bound_tool_executor = partial(
|
||||
_baseline_tool_executor, state=state, user_id=user_id, session=session
|
||||
)
|
||||
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
yield StreamStartStep()
|
||||
step_open = True
|
||||
loop_result = None
|
||||
async for loop_result in tool_call_loop(
|
||||
messages=openai_messages,
|
||||
tools=tools,
|
||||
llm_call=_bound_llm_caller,
|
||||
execute_tool=_bound_tool_executor,
|
||||
update_conversation=_baseline_conversation_updater,
|
||||
max_iterations=_MAX_TOOL_ROUNDS,
|
||||
):
|
||||
# Drain buffered events after each iteration (real-time streaming)
|
||||
for evt in state.pending_events:
|
||||
yield evt
|
||||
state.pending_events.clear()
|
||||
|
||||
# Stream a response from the model
|
||||
create_kwargs: dict[str, Any] = dict(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
# Accumulate streamed response (text + tool calls)
|
||||
round_text = ""
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
# Capture token usage from the streaming chunk.
|
||||
# OpenRouter normalises all providers into OpenAI format
|
||||
# where prompt_tokens already includes cached tokens
|
||||
# (unlike Anthropic's native API). Use += to sum all
|
||||
# tool-call rounds since each API call is independent.
|
||||
# NOTE: stream_options={"include_usage": True} is not
|
||||
# universally supported — some providers (Mistral, Llama
|
||||
# via OpenRouter) always return chunk.usage=None. When
|
||||
# that happens, tokens stay 0 and the tiktoken fallback
|
||||
# below activates. Fail-open: one round is estimated.
|
||||
if chunk.usage:
|
||||
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
# Text content
|
||||
if delta.content:
|
||||
if not text_started:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
text_started = True
|
||||
round_text += delta.content
|
||||
yield StreamTextDelta(id=text_block_id, delta=delta.content)
|
||||
|
||||
# Tool call fragments (streamed incrementally)
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_calls_by_index:
|
||||
tool_calls_by_index[idx] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
entry = tool_calls_by_index[idx]
|
||||
if tc.id:
|
||||
entry["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
entry["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
entry["arguments"] += tc.function.arguments
|
||||
|
||||
# Close text block if we had one this round
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_started = False
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Accumulate text for session persistence
|
||||
assistant_text += round_text
|
||||
|
||||
# No tool calls -> model is done
|
||||
if not tool_calls_by_index:
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
break
|
||||
|
||||
# Close step before tool execution
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
|
||||
# Append the assistant message with tool_calls to context.
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if round_text:
|
||||
assistant_msg["content"] = round_text
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": tc["arguments"] or "{}",
|
||||
},
|
||||
}
|
||||
for tc in tool_calls_by_index.values()
|
||||
]
|
||||
openai_messages.append(assistant_msg)
|
||||
|
||||
# Execute each tool call and stream events
|
||||
for tc in tool_calls_by_index.values():
|
||||
tool_call_id = tc["id"]
|
||||
tool_name = tc["name"]
|
||||
raw_args = tc["arguments"] or "{}"
|
||||
try:
|
||||
tool_args = orjson.loads(raw_args)
|
||||
except orjson.JSONDecodeError as parse_err:
|
||||
parse_error = (
|
||||
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
|
||||
)
|
||||
logger.warning("[Baseline] %s", parse_error)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=parse_error,
|
||||
success=False,
|
||||
)
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": parse_error,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=tool_args,
|
||||
)
|
||||
|
||||
# Execute via shared tool registry
|
||||
try:
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
yield result
|
||||
tool_output = (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else str(result.output)
|
||||
)
|
||||
except Exception as e:
|
||||
error_output = f"Tool execution error: {e}"
|
||||
logger.error(
|
||||
"[Baseline] Tool %s failed: %s",
|
||||
tool_name,
|
||||
error_output,
|
||||
exc_info=True,
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=error_output,
|
||||
success=False,
|
||||
)
|
||||
tool_output = error_output
|
||||
|
||||
# Append tool result to context for next round
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": tool_output,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# for-loop exhausted without break -> tool-round limit hit
|
||||
if loop_result and not loop_result.finished_naturally:
|
||||
limit_msg = (
|
||||
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
|
||||
"without a final response."
|
||||
@@ -418,11 +513,28 @@ async def stream_chat_completion_baseline(
|
||||
_stream_error = True
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
|
||||
# Close any open text/step before emitting error
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
if step_open:
|
||||
yield StreamFinishStep()
|
||||
# Close any open text block. The llm_caller's finally block
|
||||
# already appended StreamFinishStep to pending_events, so we must
|
||||
# insert StreamTextEnd *before* StreamFinishStep to preserve the
|
||||
# protocol ordering:
|
||||
# StreamStartStep -> StreamTextStart -> ...deltas... ->
|
||||
# StreamTextEnd -> StreamFinishStep
|
||||
# Appending (or yielding directly) would place it after
|
||||
# StreamFinishStep, violating the protocol.
|
||||
if state.text_started:
|
||||
# Find the last StreamFinishStep and insert before it.
|
||||
insert_pos = len(state.pending_events)
|
||||
for i in range(len(state.pending_events) - 1, -1, -1):
|
||||
if isinstance(state.pending_events[i], StreamFinishStep):
|
||||
insert_pos = i
|
||||
break
|
||||
state.pending_events.insert(
|
||||
insert_pos, StreamTextEnd(id=state.text_block_id)
|
||||
)
|
||||
# Drain pending events in correct order
|
||||
for evt in state.pending_events:
|
||||
yield evt
|
||||
state.pending_events.clear()
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
@@ -442,26 +554,21 @@ async def stream_chat_completion_baseline(
|
||||
# Skip fallback when an error occurred and no output was produced —
|
||||
# charging rate-limit tokens for completely failed requests is unfair.
|
||||
if (
|
||||
turn_prompt_tokens == 0
|
||||
and turn_completion_tokens == 0
|
||||
and not (_stream_error and not assistant_text)
|
||||
state.turn_prompt_tokens == 0
|
||||
and state.turn_completion_tokens == 0
|
||||
and not (_stream_error and not state.assistant_text)
|
||||
):
|
||||
from backend.util.prompt import (
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
|
||||
turn_prompt_tokens = max(
|
||||
state.turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.model), 1
|
||||
)
|
||||
turn_completion_tokens = estimate_token_count_str(
|
||||
assistant_text, model=config.model
|
||||
state.turn_completion_tokens = estimate_token_count_str(
|
||||
state.assistant_text, model=config.model
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||
"prompt=%d, completion=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
state.turn_prompt_tokens,
|
||||
state.turn_completion_tokens,
|
||||
)
|
||||
|
||||
# Persist token usage to session and record for rate limiting.
|
||||
@@ -471,15 +578,15 @@ async def stream_chat_completion_baseline(
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
prompt_tokens=state.turn_prompt_tokens,
|
||||
completion_tokens=state.turn_completion_tokens,
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
if state.assistant_text:
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=assistant_text)
|
||||
ChatMessage(role="assistant", content=state.assistant_text)
|
||||
)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
@@ -491,11 +598,11 @@ async def stream_chat_completion_baseline(
|
||||
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||
# On GeneratorExit the client is already gone, so unreachable yields
|
||||
# are harmless; on normal completion they reach the SSE stream.
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
if state.turn_prompt_tokens > 0 or state.turn_completion_tokens > 0:
|
||||
yield StreamUsage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=turn_prompt_tokens + turn_completion_tokens,
|
||||
prompt_tokens=state.turn_prompt_tokens,
|
||||
completion_tokens=state.turn_completion_tokens,
|
||||
total_tokens=state.turn_prompt_tokens + state.turn_completion_tokens,
|
||||
)
|
||||
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -91,6 +91,20 @@ class ChatConfig(BaseSettings):
|
||||
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
|
||||
)
|
||||
|
||||
# Cost (in credits / cents) to reset the daily rate limit using credits.
|
||||
# When a user hits their daily limit, they can spend this amount to reset
|
||||
# the daily counter and keep working. Set to 0 to disable the feature.
|
||||
rate_limit_reset_cost: int = Field(
|
||||
default=500,
|
||||
ge=0,
|
||||
description="Credit cost (in cents) for resetting the daily rate limit. 0 = disabled.",
|
||||
)
|
||||
max_daily_resets: int = Field(
|
||||
default=5,
|
||||
ge=0,
|
||||
description="Maximum number of credit-based rate limit resets per user per day. 0 = unlimited.",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
@@ -164,7 +178,7 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
Single source of truth for "will the SDK route through OpenRouter?".
|
||||
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
|
||||
present — mirrors the fallback logic in ``_build_sdk_env``.
|
||||
present — mirrors the fallback logic in ``build_sdk_env``.
|
||||
"""
|
||||
if not self.use_openrouter:
|
||||
return False
|
||||
|
||||
@@ -17,6 +17,9 @@ from backend.util.workspace import WorkspaceManager
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
|
||||
# Allowed base directory for the Read tool. Public so service.py can use it
|
||||
# for sweep operations without depending on a private implementation detail.
|
||||
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
|
||||
@@ -43,6 +46,12 @@ _current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
)
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
# Current execution's capability filter. None means "no restrictions".
|
||||
# Set by set_execution_context(); read by run_block and service.py.
|
||||
_current_permissions: "ContextVar[CopilotPermissions | None]" = ContextVar(
|
||||
"_current_permissions", default=None
|
||||
)
|
||||
|
||||
|
||||
def encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
@@ -63,6 +72,7 @@ def set_execution_context(
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> None:
|
||||
"""Set per-turn context variables used by file-resolution tool handlers."""
|
||||
_current_user_id.set(user_id)
|
||||
@@ -70,6 +80,7 @@ def set_execution_context(
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
_current_permissions.set(permissions)
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
@@ -77,6 +88,11 @@ def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
return _current_user_id.get(), _current_session.get()
|
||||
|
||||
|
||||
def get_current_permissions() -> "CopilotPermissions | None":
|
||||
"""Return the capability filter for the current execution, or None if unrestricted."""
|
||||
return _current_permissions.get()
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current session, or None if not active."""
|
||||
return _current_sandbox.get()
|
||||
@@ -88,17 +104,32 @@ def get_sdk_cwd() -> str:
|
||||
|
||||
|
||||
E2B_WORKDIR = "/home/user"
|
||||
E2B_ALLOWED_DIRS: tuple[str, ...] = (E2B_WORKDIR, "/tmp")
|
||||
E2B_ALLOWED_DIRS_STR: str = " or ".join(E2B_ALLOWED_DIRS)
|
||||
|
||||
|
||||
def is_within_allowed_dirs(path: str) -> bool:
|
||||
"""Return True if *path* is within one of the allowed sandbox directories."""
|
||||
for allowed in E2B_ALLOWED_DIRS:
|
||||
if path == allowed or path.startswith(allowed + "/"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def resolve_sandbox_path(path: str) -> str:
|
||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||
"""Normalise *path* to an absolute sandbox path under an allowed directory.
|
||||
|
||||
Allowed directories: ``/home/user`` and ``/tmp``.
|
||||
Relative paths are resolved against ``/home/user``.
|
||||
|
||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||
"""
|
||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||
normalized = os.path.normpath(candidate)
|
||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||
if not is_within_allowed_dirs(normalized):
|
||||
raise ValueError(
|
||||
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(path)}"
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from backend.copilot.context import (
|
||||
SDK_PROJECTS_DIR,
|
||||
_current_project_dir,
|
||||
get_current_permissions,
|
||||
get_current_sandbox,
|
||||
get_execution_context,
|
||||
get_sdk_cwd,
|
||||
@@ -18,6 +19,7 @@ from backend.copilot.context import (
|
||||
resolve_sandbox_path,
|
||||
set_execution_context,
|
||||
)
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
|
||||
def _make_session() -> MagicMock:
|
||||
@@ -61,6 +63,19 @@ def test_get_current_sandbox_returns_set_value():
|
||||
assert get_current_sandbox() is mock_sandbox
|
||||
|
||||
|
||||
def test_set_and_get_current_permissions():
|
||||
"""set_execution_context stores permissions; get_current_permissions returns it."""
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
set_execution_context("u1", _make_session(), permissions=perms)
|
||||
assert get_current_permissions() is perms
|
||||
|
||||
|
||||
def test_get_current_permissions_defaults_to_none():
|
||||
"""get_current_permissions returns None when no permissions have been set."""
|
||||
set_execution_context("u1", _make_session())
|
||||
assert get_current_permissions() is None
|
||||
|
||||
|
||||
def test_get_sdk_cwd_empty_when_not_set():
|
||||
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
|
||||
set_execution_context("u1", _make_session(), sdk_cwd=None)
|
||||
@@ -183,10 +198,32 @@ def test_resolve_sandbox_path_normalizes_dots():
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_escape_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_absolute_outside_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
with pytest.raises(ValueError):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_tmp_allowed():
|
||||
assert resolve_sandbox_path("/tmp/data.txt") == "/tmp/data.txt"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_tmp_nested():
|
||||
assert resolve_sandbox_path("/tmp/a/b/c.txt") == "/tmp/a/b/c.txt"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_tmp_itself():
|
||||
assert resolve_sandbox_path("/tmp") == "/tmp"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_tmp_escape_raises():
|
||||
with pytest.raises(ValueError):
|
||||
resolve_sandbox_path("/tmp/../etc/passwd")
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_tmp_prefix_collision_raises():
|
||||
with pytest.raises(ValueError):
|
||||
resolve_sandbox_path("/tmp_evil/malicious.txt")
|
||||
|
||||
@@ -18,7 +18,7 @@ from prisma.types import (
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -217,6 +217,9 @@ async def add_chat_messages_batch(
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
if msg.get("duration_ms") is not None:
|
||||
data["durationMs"] = msg["duration_ms"]
|
||||
|
||||
messages_data.append(data)
|
||||
|
||||
# Run create_many and session update in parallel within transaction
|
||||
@@ -359,3 +362,22 @@ async def update_tool_message_content(
|
||||
f"tool_call_id {tool_call_id}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
|
||||
"""Set durationMs on the last assistant message in a session.
|
||||
|
||||
Also invalidates the Redis session cache so the next GET returns
|
||||
the updated duration.
|
||||
"""
|
||||
last_msg = await PrismaChatMessage.prisma().find_first(
|
||||
where={"sessionId": session_id, "role": "assistant"},
|
||||
order={"sequence": "desc"},
|
||||
)
|
||||
if last_msg:
|
||||
await PrismaChatMessage.prisma().update(
|
||||
where={"id": last_msg.id},
|
||||
data={"durationMs": duration_ms},
|
||||
)
|
||||
# Invalidate cache so the session is re-fetched from DB with durationMs
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
@@ -14,7 +14,7 @@ import time
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.response_model import StreamError
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
@@ -23,6 +23,7 @@ from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.process import set_service_name
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
||||
|
||||
@@ -153,8 +154,6 @@ class CoPilotProcessor:
|
||||
worker's event loop, ensuring ``aiohttp.ClientSession.close()``
|
||||
runs on the same loop that created the session.
|
||||
"""
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
coro = shutdown_workspace_storage()
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
|
||||
@@ -268,35 +267,37 @@ class CoPilotProcessor:
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
async for chunk in stream_fn(
|
||||
# stream_and_publish wraps the raw stream with registry
|
||||
# publishing (shared with collect_copilot_response).
|
||||
raw_stream = stream_fn(
|
||||
session_id=entry.session_id,
|
||||
message=entry.message if entry.message else None,
|
||||
is_user_message=entry.is_user_message,
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
)
|
||||
async for chunk in stream_registry.stream_and_publish(
|
||||
session_id=entry.session_id,
|
||||
turn_id=entry.turn_id,
|
||||
stream=raw_stream,
|
||||
):
|
||||
if cancel.is_set():
|
||||
log.info("Cancel requested, breaking stream")
|
||||
break
|
||||
|
||||
# Capture StreamError so mark_session_completed receives
|
||||
# the error message (stream_and_publish yields but does
|
||||
# not publish StreamError — that's done by mark_session_completed).
|
||||
if isinstance(chunk, StreamError):
|
||||
error_msg = chunk.errorText
|
||||
break
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= refresh_interval:
|
||||
cluster_lock.refresh()
|
||||
last_refresh = current_time
|
||||
|
||||
# Skip StreamFinish — mark_session_completed publishes it.
|
||||
if isinstance(chunk, StreamFinish):
|
||||
continue
|
||||
|
||||
try:
|
||||
await stream_registry.publish_chunk(entry.turn_id, chunk)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error publishing chunk {type(chunk).__name__}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Stream loop completed
|
||||
if cancel.is_set():
|
||||
log.info("Stream cancelled by user")
|
||||
|
||||
@@ -54,6 +54,7 @@ class ChatMessage(BaseModel):
|
||||
refusal: str | None = None
|
||||
tool_calls: list[dict] | None = None
|
||||
function_call: dict | None = None
|
||||
duration_ms: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
|
||||
@@ -66,6 +67,7 @@ class ChatMessage(BaseModel):
|
||||
refusal=prisma_message.refusal,
|
||||
tool_calls=_parse_json_field(prisma_message.toolCalls),
|
||||
function_call=_parse_json_field(prisma_message.functionCall),
|
||||
duration_ms=prisma_message.durationMs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
430
autogpt_platform/backend/backend/copilot/permissions.py
Normal file
430
autogpt_platform/backend/backend/copilot/permissions.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""Copilot execution permissions — tool and block allow/deny filtering.
|
||||
|
||||
:class:`CopilotPermissions` is the single model used everywhere:
|
||||
|
||||
- ``AutoPilotBlock`` reads four block-input fields and builds one instance.
|
||||
- ``stream_chat_completion_sdk`` applies it when constructing
|
||||
``ClaudeAgentOptions.allowed_tools`` / ``disallowed_tools``.
|
||||
- ``run_block`` reads it from the contextvar to gate block execution.
|
||||
- Recursive (sub-agent) invocations merge parent and child so children
|
||||
can only be *more* restrictive, never more permissive.
|
||||
|
||||
Tool names
|
||||
----------
|
||||
Users specify the **short name** as it appears in ``TOOL_REGISTRY`` (e.g.
|
||||
``run_block``, ``web_fetch``) or as an SDK built-in (e.g. ``Read``,
|
||||
``Task``, ``WebSearch``). Internally these are mapped to the full SDK
|
||||
format (``mcp__copilot__run_block``, ``Read``, …) by
|
||||
:func:`apply_tool_permissions`.
|
||||
|
||||
Block identifiers
|
||||
-----------------
|
||||
Each entry in ``blocks`` may be one of:
|
||||
|
||||
- A **full UUID** (``c069dc6b-c3ed-4c12-b6e5-d47361e64ce6``)
|
||||
- A **partial UUID** — the first 8-character hex segment (``c069dc6b``)
|
||||
- A **block name** (case-insensitive, e.g. ``"HTTP Request"``)
|
||||
|
||||
:func:`validate_block_identifiers` resolves all entries against the live
|
||||
block registry and returns any that could not be matched.
|
||||
|
||||
Semantics
|
||||
---------
|
||||
``tools_exclude=True`` (default) — ``tools`` is a **blacklist**; listed
|
||||
tools are denied and everything else is allowed. An empty list means
|
||||
"allow all" (no filtering).
|
||||
|
||||
``tools_exclude=False`` — ``tools`` is a **whitelist**; only listed tools
|
||||
are allowed.
|
||||
|
||||
``blocks_exclude`` follows the same pattern for ``blocks``.
|
||||
|
||||
Recursion inheritance
|
||||
---------------------
|
||||
:meth:`CopilotPermissions.merged_with_parent` produces a new instance that
|
||||
is at most as permissive as the parent:
|
||||
|
||||
- Tools: effective-allowed sets are intersected then stored as a whitelist.
|
||||
- Blocks: the parent is stored in ``_parent`` and consulted during every
|
||||
:meth:`is_block_allowed` call so both constraints must pass.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Literal, get_args
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants — single source of truth for all accepted tool names
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Literal type combining all valid tool names — used by AutoPilotBlock.Input
|
||||
# so the frontend renders a multi-select dropdown.
|
||||
# This is the SINGLE SOURCE OF TRUTH. All other name sets are derived from it.
|
||||
ToolName = Literal[
|
||||
# Platform tools (must match keys in TOOL_REGISTRY)
|
||||
"add_understanding",
|
||||
"bash_exec",
|
||||
"browser_act",
|
||||
"browser_navigate",
|
||||
"browser_screenshot",
|
||||
"connect_integration",
|
||||
"continue_run_block",
|
||||
"create_agent",
|
||||
"create_feature_request",
|
||||
"create_folder",
|
||||
"customize_agent",
|
||||
"delete_folder",
|
||||
"delete_workspace_file",
|
||||
"edit_agent",
|
||||
"find_agent",
|
||||
"find_block",
|
||||
"find_library_agent",
|
||||
"fix_agent_graph",
|
||||
"get_agent_building_guide",
|
||||
"get_doc_page",
|
||||
"get_mcp_guide",
|
||||
"list_folders",
|
||||
"list_workspace_files",
|
||||
"move_agents_to_folder",
|
||||
"move_folder",
|
||||
"read_workspace_file",
|
||||
"run_agent",
|
||||
"run_block",
|
||||
"run_mcp_tool",
|
||||
"search_docs",
|
||||
"search_feature_requests",
|
||||
"update_folder",
|
||||
"validate_agent_graph",
|
||||
"view_agent_output",
|
||||
"web_fetch",
|
||||
"write_workspace_file",
|
||||
# SDK built-ins
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Read",
|
||||
"Task",
|
||||
"TodoWrite",
|
||||
"WebSearch",
|
||||
"Write",
|
||||
]
|
||||
|
||||
# Frozen set of all valid tool names — derived from the Literal.
|
||||
ALL_TOOL_NAMES: frozenset[str] = frozenset(get_args(ToolName))
|
||||
|
||||
# SDK built-in tool names — uppercase-initial names are SDK built-ins.
|
||||
SDK_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
|
||||
n for n in ALL_TOOL_NAMES if n[0].isupper()
|
||||
)
|
||||
|
||||
# Platform tool names — everything that isn't an SDK built-in.
|
||||
PLATFORM_TOOL_NAMES: frozenset[str] = ALL_TOOL_NAMES - SDK_BUILTIN_TOOL_NAMES
|
||||
|
||||
# Compiled regex patterns for block identifier classification.
|
||||
_FULL_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_PARTIAL_UUID_RE = re.compile(r"^[0-9a-f]{8}$", re.IGNORECASE)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper — block identifier matching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _block_matches(identifier: str, block_id: str, block_name: str) -> bool:
|
||||
"""Return True if *identifier* resolves to the given block.
|
||||
|
||||
Resolution order:
|
||||
1. Full UUID — exact case-insensitive match against *block_id*.
|
||||
2. Partial UUID (8 hex chars, first segment) — prefix match.
|
||||
3. Name — case-insensitive equality against *block_name*.
|
||||
"""
|
||||
ident = identifier.strip()
|
||||
if _FULL_UUID_RE.match(ident):
|
||||
return ident.lower() == block_id.lower()
|
||||
if _PARTIAL_UUID_RE.match(ident):
|
||||
return block_id.lower().startswith(ident.lower())
|
||||
return ident.lower() == block_name.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CopilotPermissions(BaseModel):
|
||||
"""Capability filter for a single copilot execution.
|
||||
|
||||
Attributes:
|
||||
tools: Tool names to filter (short names, e.g. ``run_block``).
|
||||
tools_exclude: When True (default) ``tools`` is a blacklist;
|
||||
when False it is a whitelist. Ignored when *tools* is empty.
|
||||
blocks: Block identifiers (name, full UUID, or 8-char partial UUID).
|
||||
blocks_exclude: Same semantics as *tools_exclude* but for blocks.
|
||||
"""
|
||||
|
||||
tools: list[str] = []
|
||||
tools_exclude: bool = True
|
||||
blocks: list[str] = []
|
||||
blocks_exclude: bool = True
|
||||
|
||||
# Private: parent permissions for recursion inheritance.
|
||||
# Set only by merged_with_parent(); never exposed in block input schema.
|
||||
_parent: CopilotPermissions | None = PrivateAttr(default=None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def effective_allowed_tools(self, all_tools: frozenset[str]) -> frozenset[str]:
|
||||
"""Compute the set of short tool names that are permitted.
|
||||
|
||||
Args:
|
||||
all_tools: Universe of valid short tool names.
|
||||
|
||||
Returns:
|
||||
Subset of *all_tools* that pass the filter.
|
||||
"""
|
||||
if not self.tools:
|
||||
return frozenset(all_tools)
|
||||
tool_set = frozenset(self.tools)
|
||||
if self.tools_exclude:
|
||||
return all_tools - tool_set
|
||||
return all_tools & tool_set
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Block helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_block_allowed(self, block_id: str, block_name: str) -> bool:
|
||||
"""Return True if the block may be executed under these permissions.
|
||||
|
||||
Checks this instance first, then consults the parent (if any) so
|
||||
the entire inheritance chain is respected.
|
||||
"""
|
||||
if not self._check_block_locally(block_id, block_name):
|
||||
return False
|
||||
if self._parent is not None:
|
||||
return self._parent.is_block_allowed(block_id, block_name)
|
||||
return True
|
||||
|
||||
def _check_block_locally(self, block_id: str, block_name: str) -> bool:
|
||||
"""Check *only* this instance's block filter (ignores parent)."""
|
||||
if not self.blocks:
|
||||
return True # No filter → allow all
|
||||
matched = any(
|
||||
_block_matches(identifier, block_id, block_name)
|
||||
for identifier in self.blocks
|
||||
)
|
||||
return not matched if self.blocks_exclude else matched
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Recursion / merging
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def merged_with_parent(
|
||||
self,
|
||||
parent: CopilotPermissions,
|
||||
all_tools: frozenset[str],
|
||||
) -> CopilotPermissions:
|
||||
"""Return a new instance that is at most as permissive as *parent*.
|
||||
|
||||
- Tools: intersection of effective-allowed sets, stored as a whitelist.
|
||||
- Blocks: parent is stored internally; both constraints are applied
|
||||
during :meth:`is_block_allowed`.
|
||||
"""
|
||||
merged_tools = self.effective_allowed_tools(
|
||||
all_tools
|
||||
) & parent.effective_allowed_tools(all_tools)
|
||||
result = CopilotPermissions(
|
||||
tools=sorted(merged_tools),
|
||||
tools_exclude=False,
|
||||
blocks=self.blocks,
|
||||
blocks_exclude=self.blocks_exclude,
|
||||
)
|
||||
result._parent = parent
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Convenience
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Return True when no filtering is configured (allow-all passthrough)."""
|
||||
return not self.tools and not self.blocks and self._parent is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def all_known_tool_names() -> frozenset[str]:
|
||||
"""Return all short tool names accepted in *tools*.
|
||||
|
||||
Returns the pre-computed ``ALL_TOOL_NAMES`` set (derived from the
|
||||
``ToolName`` Literal). On first call, also verifies consistency with
|
||||
the live ``TOOL_REGISTRY``.
|
||||
"""
|
||||
_assert_tool_names_consistent()
|
||||
return ALL_TOOL_NAMES
|
||||
|
||||
|
||||
def validate_tool_names(tools: list[str]) -> list[str]:
|
||||
"""Return entries in *tools* that are not valid tool names.
|
||||
|
||||
Args:
|
||||
tools: List of short tool name strings to validate.
|
||||
|
||||
Returns:
|
||||
List of invalid names (empty if all are valid).
|
||||
"""
|
||||
return [t for t in tools if t not in ALL_TOOL_NAMES]
|
||||
|
||||
|
||||
_tool_names_checked = False
|
||||
|
||||
|
||||
def _assert_tool_names_consistent() -> None:
|
||||
"""Verify that ``PLATFORM_TOOL_NAMES`` matches ``TOOL_REGISTRY`` keys.
|
||||
|
||||
Called once lazily (TOOL_REGISTRY has heavy imports). Raises
|
||||
``AssertionError`` with a helpful diff if they diverge.
|
||||
"""
|
||||
global _tool_names_checked
|
||||
if _tool_names_checked:
|
||||
return
|
||||
_tool_names_checked = True
|
||||
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
registry_keys: frozenset[str] = frozenset(TOOL_REGISTRY.keys())
|
||||
declared: frozenset[str] = PLATFORM_TOOL_NAMES
|
||||
if registry_keys != declared:
|
||||
missing = registry_keys - declared
|
||||
extra = declared - registry_keys
|
||||
parts: list[str] = [
|
||||
"PLATFORM_TOOL_NAMES in permissions.py is out of sync with TOOL_REGISTRY."
|
||||
]
|
||||
if missing:
|
||||
parts.append(f" Missing from PLATFORM_TOOL_NAMES: {sorted(missing)}")
|
||||
if extra:
|
||||
parts.append(f" Extra in PLATFORM_TOOL_NAMES: {sorted(extra)}")
|
||||
parts.append(" Update the ToolName Literal to match.")
|
||||
raise AssertionError("\n".join(parts))
|
||||
|
||||
|
||||
async def validate_block_identifiers(
|
||||
identifiers: list[str],
|
||||
) -> list[str]:
|
||||
"""Resolve each block identifier and return those that could not be matched.
|
||||
|
||||
Args:
|
||||
identifiers: List of block identifiers (name, full UUID, or partial UUID).
|
||||
|
||||
Returns:
|
||||
List of identifiers that matched no known block.
|
||||
"""
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
# get_blocks() returns dict[block_id_str, BlockClass]; instantiate once to get names.
|
||||
block_registry = get_blocks()
|
||||
block_info = {bid: cls().name for bid, cls in block_registry.items()}
|
||||
invalid: list[str] = []
|
||||
for ident in identifiers:
|
||||
matched = any(
|
||||
_block_matches(ident, bid, bname) for bid, bname in block_info.items()
|
||||
)
|
||||
if not matched:
|
||||
invalid.append(ident)
|
||||
return invalid
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDK tool-list application
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_tool_permissions(
|
||||
permissions: CopilotPermissions,
|
||||
*,
|
||||
use_e2b: bool = False,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Compute (allowed_tools, extra_disallowed) for :class:`ClaudeAgentOptions`.
|
||||
|
||||
Takes the base allowed/disallowed lists from
|
||||
:func:`~backend.copilot.sdk.tool_adapter.get_copilot_tool_names` /
|
||||
:func:`~backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools` and
|
||||
applies *permissions* on top.
|
||||
|
||||
Returns:
|
||||
``(allowed_tools, extra_disallowed)`` where *allowed_tools* is the
|
||||
possibly-narrowed list to pass to ``ClaudeAgentOptions.allowed_tools``
|
||||
and *extra_disallowed* is the list to pass to
|
||||
``ClaudeAgentOptions.disallowed_tools``.
|
||||
"""
|
||||
from backend.copilot.sdk.tool_adapter import (
|
||||
_READ_TOOL_NAME,
|
||||
MCP_TOOL_PREFIX,
|
||||
get_copilot_tool_names,
|
||||
get_sdk_disallowed_tools,
|
||||
)
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
base_allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
base_disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
|
||||
if permissions.is_empty():
|
||||
return base_allowed, base_disallowed
|
||||
|
||||
all_tools = all_known_tool_names()
|
||||
effective = permissions.effective_allowed_tools(all_tools)
|
||||
|
||||
# In E2B mode, SDK built-in file tools (Read, Write, Edit, Glob, Grep)
|
||||
# are replaced by MCP equivalents (read_file, write_file, ...).
|
||||
# Map each SDK built-in name to its E2B MCP name so users can use the
|
||||
# familiar names in their permissions and the E2B tools are included.
|
||||
_SDK_TO_E2B: dict[str, str] = {}
|
||||
if use_e2b:
|
||||
from backend.copilot.sdk.e2b_file_tools import E2B_FILE_TOOL_NAMES
|
||||
|
||||
_SDK_TO_E2B = dict(
|
||||
zip(
|
||||
["Read", "Write", "Edit", "Glob", "Grep"],
|
||||
E2B_FILE_TOOL_NAMES,
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
|
||||
# Build an updated allowed list by mapping short names → SDK names and
|
||||
# keeping only those present in the original base_allowed list.
|
||||
def to_sdk_names(short: str) -> list[str]:
|
||||
names: list[str] = []
|
||||
if short in TOOL_REGISTRY:
|
||||
names.append(f"{MCP_TOOL_PREFIX}{short}")
|
||||
elif short in _SDK_TO_E2B:
|
||||
# E2B mode: map SDK built-in file tool to its MCP equivalent.
|
||||
names.append(f"{MCP_TOOL_PREFIX}{_SDK_TO_E2B[short]}")
|
||||
else:
|
||||
names.append(short) # SDK built-in — used as-is
|
||||
return names
|
||||
|
||||
# short names permitted by permissions
|
||||
permitted_sdk: set[str] = set()
|
||||
for s in effective:
|
||||
permitted_sdk.update(to_sdk_names(s))
|
||||
# Always include the internal Read tool (used by SDK for large/truncated outputs)
|
||||
permitted_sdk.add(f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}")
|
||||
|
||||
filtered_allowed = [t for t in base_allowed if t in permitted_sdk]
|
||||
|
||||
# Extra disallowed = tools that were in base_allowed but are now removed
|
||||
removed = set(base_allowed) - set(filtered_allowed)
|
||||
extra_disallowed = list(set(base_disallowed) | removed)
|
||||
|
||||
return filtered_allowed, extra_disallowed
|
||||
579
autogpt_platform/backend/backend/copilot/permissions_test.py
Normal file
579
autogpt_platform/backend/backend/copilot/permissions_test.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""Tests for CopilotPermissions — tool/block capability filtering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.permissions import (
|
||||
ALL_TOOL_NAMES,
|
||||
PLATFORM_TOOL_NAMES,
|
||||
SDK_BUILTIN_TOOL_NAMES,
|
||||
CopilotPermissions,
|
||||
_block_matches,
|
||||
all_known_tool_names,
|
||||
apply_tool_permissions,
|
||||
validate_block_identifiers,
|
||||
validate_tool_names,
|
||||
)
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _block_matches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBlockMatches:
|
||||
BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
BLOCK_NAME = "HTTP Request"
|
||||
|
||||
def test_full_uuid_match(self):
|
||||
assert _block_matches(self.BLOCK_ID, self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_full_uuid_case_insensitive(self):
|
||||
assert _block_matches(self.BLOCK_ID.upper(), self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_full_uuid_no_match(self):
|
||||
other = "aaaaaaaa-0000-0000-0000-000000000000"
|
||||
assert not _block_matches(other, self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_match(self):
|
||||
assert _block_matches("c069dc6b", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_case_insensitive(self):
|
||||
assert _block_matches("C069DC6B", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_no_match(self):
|
||||
assert not _block_matches("deadbeef", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_name_match(self):
|
||||
assert _block_matches("HTTP Request", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_name_case_insensitive(self):
|
||||
assert _block_matches("http request", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
assert _block_matches("HTTP REQUEST", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_name_no_match(self):
|
||||
assert not _block_matches("Unknown Block", self.BLOCK_ID, self.BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_not_matching_as_name(self):
|
||||
# "c069dc6b" is 8 hex chars → treated as partial UUID, NOT name match
|
||||
assert not _block_matches(
|
||||
"c069dc6b", "ffffffff-0000-0000-0000-000000000000", "c069dc6b"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.effective_allowed_tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
ALL_TOOLS = frozenset(
|
||||
["run_block", "web_fetch", "bash_exec", "find_agent", "Task", "Read"]
|
||||
)
|
||||
|
||||
|
||||
class TestEffectiveAllowedTools:
|
||||
def test_empty_list_allows_all(self):
|
||||
perms = CopilotPermissions(tools=[], tools_exclude=True)
|
||||
assert perms.effective_allowed_tools(ALL_TOOLS) == ALL_TOOLS
|
||||
|
||||
def test_empty_whitelist_allows_all(self):
|
||||
# edge: tools_exclude=False but empty list → allow all
|
||||
perms = CopilotPermissions(tools=[], tools_exclude=False)
|
||||
assert perms.effective_allowed_tools(ALL_TOOLS) == ALL_TOOLS
|
||||
|
||||
def test_blacklist_removes_listed(self):
|
||||
perms = CopilotPermissions(tools=["bash_exec", "web_fetch"], tools_exclude=True)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert "bash_exec" not in result
|
||||
assert "web_fetch" not in result
|
||||
assert "run_block" in result
|
||||
assert "Task" in result
|
||||
|
||||
def test_whitelist_keeps_only_listed(self):
|
||||
perms = CopilotPermissions(tools=["run_block", "Task"], tools_exclude=False)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert result == frozenset(["run_block", "Task"])
|
||||
|
||||
def test_whitelist_unknown_tool_yields_empty(self):
|
||||
perms = CopilotPermissions(tools=["nonexistent"], tools_exclude=False)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert result == frozenset()
|
||||
|
||||
def test_blacklist_unknown_tool_ignored(self):
|
||||
perms = CopilotPermissions(tools=["nonexistent"], tools_exclude=True)
|
||||
result = perms.effective_allowed_tools(ALL_TOOLS)
|
||||
assert result == ALL_TOOLS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.is_block_allowed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
BLOCK_NAME = "HTTP Request"
|
||||
|
||||
|
||||
class TestIsBlockAllowed:
|
||||
def test_empty_allows_everything(self):
|
||||
perms = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_blacklist_blocks_listed(self):
|
||||
perms = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_blacklist_allows_unlisted(self):
|
||||
perms = CopilotPermissions(blocks=["Other Block"], blocks_exclude=True)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_whitelist_allows_listed(self):
|
||||
perms = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_whitelist_blocks_unlisted(self):
|
||||
perms = CopilotPermissions(blocks=["Other Block"], blocks_exclude=False)
|
||||
assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_partial_uuid_blacklist(self):
|
||||
perms = CopilotPermissions(blocks=["c069dc6b"], blocks_exclude=True)
|
||||
assert not perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_full_uuid_whitelist(self):
|
||||
perms = CopilotPermissions(blocks=[BLOCK_ID], blocks_exclude=False)
|
||||
assert perms.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_parent_blocks_when_child_allows(self):
|
||||
parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
child = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
child._parent = parent
|
||||
assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_parent_allows_when_child_blocks(self):
|
||||
parent = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
child = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
child._parent = parent
|
||||
assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_both_must_allow(self):
|
||||
parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
|
||||
child = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=False)
|
||||
child._parent = parent
|
||||
assert child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
def test_grandparent_blocks_propagate(self):
|
||||
grandparent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
parent = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
parent._parent = grandparent
|
||||
child = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
child._parent = parent
|
||||
assert not child.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.merged_with_parent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergedWithParent:
|
||||
def test_tool_intersection(self):
|
||||
all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert "bash_exec" not in effective
|
||||
assert "web_fetch" not in effective
|
||||
assert "run_block" in effective
|
||||
|
||||
def test_parent_whitelist_narrows_child(self):
|
||||
all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
child = CopilotPermissions(tools=[], tools_exclude=True) # allow all
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert effective == frozenset(["run_block"])
|
||||
|
||||
def test_child_cannot_expand_parent_whitelist(self):
|
||||
all_t = frozenset(["run_block", "web_fetch", "bash_exec"])
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
child = CopilotPermissions(
|
||||
tools=["run_block", "bash_exec"], tools_exclude=False
|
||||
)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
# bash_exec was not in parent's whitelist → must not appear
|
||||
assert "bash_exec" not in effective
|
||||
assert "run_block" in effective
|
||||
|
||||
def test_merged_stored_as_whitelist(self):
|
||||
all_t = frozenset(["run_block", "web_fetch"])
|
||||
parent = CopilotPermissions(tools=[], tools_exclude=True)
|
||||
child = CopilotPermissions(tools=[], tools_exclude=True)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
assert not merged.tools_exclude # stored as whitelist
|
||||
assert set(merged.tools) == {"run_block", "web_fetch"}
|
||||
|
||||
def test_block_parent_stored(self):
|
||||
all_t = frozenset(["run_block"])
|
||||
parent = CopilotPermissions(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
child = CopilotPermissions(blocks=[], blocks_exclude=True)
|
||||
merged = child.merged_with_parent(parent, all_t)
|
||||
# Parent restriction is preserved via _parent
|
||||
assert not merged.is_block_allowed(BLOCK_ID, BLOCK_NAME)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CopilotPermissions.is_empty
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsEmpty:
|
||||
def test_default_is_empty(self):
|
||||
assert CopilotPermissions().is_empty()
|
||||
|
||||
def test_with_tools_not_empty(self):
|
||||
assert not CopilotPermissions(tools=["bash_exec"]).is_empty()
|
||||
|
||||
def test_with_blocks_not_empty(self):
|
||||
assert not CopilotPermissions(blocks=["HTTP Request"]).is_empty()
|
||||
|
||||
def test_with_parent_not_empty(self):
|
||||
perms = CopilotPermissions()
|
||||
perms._parent = CopilotPermissions(tools=["bash_exec"])
|
||||
assert not perms.is_empty()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_tool_names
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateToolNames:
|
||||
def test_valid_registry_tool(self):
|
||||
assert validate_tool_names(["run_block", "web_fetch"]) == []
|
||||
|
||||
def test_valid_sdk_builtin(self):
|
||||
assert validate_tool_names(["Read", "Task", "WebSearch"]) == []
|
||||
|
||||
def test_invalid_tool(self):
|
||||
result = validate_tool_names(["nonexistent_tool"])
|
||||
assert "nonexistent_tool" in result
|
||||
|
||||
def test_mixed(self):
|
||||
result = validate_tool_names(["run_block", "fake_tool"])
|
||||
assert "fake_tool" in result
|
||||
assert "run_block" not in result
|
||||
|
||||
def test_empty_list(self):
|
||||
assert validate_tool_names([]) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_block_identifiers (async)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestValidateBlockIdentifiers:
|
||||
async def test_empty_list(self):
|
||||
result = await validate_block_identifiers([])
|
||||
assert result == []
|
||||
|
||||
async def test_valid_full_uuid(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(
|
||||
["c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"]
|
||||
)
|
||||
assert result == []
|
||||
|
||||
async def test_invalid_identifier(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(["totally_unknown"])
|
||||
assert "totally_unknown" in result
|
||||
|
||||
async def test_partial_uuid_match(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(["c069dc6b"])
|
||||
assert result == []
|
||||
|
||||
async def test_name_match(self, mocker):
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.return_value.name = "HTTP Request"
|
||||
mocker.patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block},
|
||||
)
|
||||
result = await validate_block_identifiers(["http request"])
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# apply_tool_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApplyToolPermissions:
|
||||
def test_empty_permissions_returns_base_unchanged(self, mocker):
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=["mcp__copilot__run_block", "mcp__copilot__web_fetch", "Task"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object(), "web_fetch": object()},
|
||||
)
|
||||
perms = CopilotPermissions()
|
||||
allowed, disallowed = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
assert "mcp__copilot__web_fetch" in allowed
|
||||
|
||||
def test_blacklist_removes_tool(self, mocker):
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__web_fetch",
|
||||
"mcp__copilot__bash_exec",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{
|
||||
"run_block": object(),
|
||||
"web_fetch": object(),
|
||||
"bash_exec": object(),
|
||||
},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "web_fetch", "bash_exec", "Task"]),
|
||||
)
|
||||
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__bash_exec" not in allowed
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
|
||||
def test_whitelist_keeps_only_listed(self, mocker):
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__web_fetch",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object(), "web_fetch": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "web_fetch", "Task", "WebSearch"]),
|
||||
)
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
assert "mcp__copilot__web_fetch" not in allowed
|
||||
assert "Task" not in allowed
|
||||
|
||||
def test_read_tool_always_included_even_when_blacklisted(self, mocker):
|
||||
"""mcp__copilot__Read must stay in allowed even if Read is explicitly blacklisted."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=[],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Task"]),
|
||||
)
|
||||
# Explicitly blacklist Read
|
||||
perms = CopilotPermissions(tools=["Read"], tools_exclude=True)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__Read" in allowed # always preserved for SDK internals
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
assert "Task" in allowed
|
||||
|
||||
def test_read_tool_always_included_with_narrow_whitelist(self, mocker):
|
||||
"""mcp__copilot__Read must stay in allowed even when not in a whitelist."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=[],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Task"]),
|
||||
)
|
||||
# Whitelist only run_block — Read not listed
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=False)
|
||||
assert "mcp__copilot__Read" in allowed # always preserved for SDK internals
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
|
||||
def test_e2b_file_tools_included_when_sdk_builtin_whitelisted(self, mocker):
|
||||
"""In E2B mode, whitelisting 'Read' must include mcp__copilot__read_file."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"mcp__copilot__read_file",
|
||||
"mcp__copilot__write_file",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash", "Read", "Write", "Edit", "Glob", "Grep"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Write", "Task"]),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.e2b_file_tools.E2B_FILE_TOOL_NAMES",
|
||||
["read_file", "write_file", "edit_file", "glob", "grep"],
|
||||
)
|
||||
# Whitelist Read and run_block — E2B read_file should be included
|
||||
perms = CopilotPermissions(tools=["Read", "run_block"], tools_exclude=False)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=True)
|
||||
assert "mcp__copilot__read_file" in allowed
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
# Write not whitelisted — write_file should NOT be included
|
||||
assert "mcp__copilot__write_file" not in allowed
|
||||
|
||||
def test_e2b_file_tools_excluded_when_sdk_builtin_blacklisted(self, mocker):
|
||||
"""In E2B mode, blacklisting 'Read' must also remove mcp__copilot__read_file."""
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_copilot_tool_names",
|
||||
return_value=[
|
||||
"mcp__copilot__run_block",
|
||||
"mcp__copilot__Read",
|
||||
"mcp__copilot__read_file",
|
||||
"Task",
|
||||
],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools",
|
||||
return_value=["Bash", "Read", "Write", "Edit", "Glob", "Grep"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": object()},
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset(["run_block", "Read", "Task"]),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.sdk.e2b_file_tools.E2B_FILE_TOOL_NAMES",
|
||||
["read_file", "write_file", "edit_file", "glob", "grep"],
|
||||
)
|
||||
# Blacklist Read — E2B read_file should also be removed
|
||||
perms = CopilotPermissions(tools=["Read"], tools_exclude=True)
|
||||
allowed, _ = apply_tool_permissions(perms, use_e2b=True)
|
||||
assert "mcp__copilot__read_file" not in allowed
|
||||
assert "mcp__copilot__run_block" in allowed
|
||||
# mcp__copilot__Read is always preserved for SDK internals
|
||||
assert "mcp__copilot__Read" in allowed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDK_BUILTIN_TOOL_NAMES sanity check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSdkBuiltinToolNames:
|
||||
def test_expected_builtins_present(self):
|
||||
expected = {
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
"TodoWrite",
|
||||
}
|
||||
assert expected.issubset(SDK_BUILTIN_TOOL_NAMES)
|
||||
|
||||
def test_platform_names_match_tool_registry(self):
|
||||
"""PLATFORM_TOOL_NAMES (derived from ToolName Literal) must match TOOL_REGISTRY keys."""
|
||||
registry_keys = frozenset(TOOL_REGISTRY.keys())
|
||||
assert PLATFORM_TOOL_NAMES == registry_keys, (
|
||||
f"ToolName Literal is out of sync with TOOL_REGISTRY. "
|
||||
f"Missing: {registry_keys - PLATFORM_TOOL_NAMES}, "
|
||||
f"Extra: {PLATFORM_TOOL_NAMES - registry_keys}"
|
||||
)
|
||||
|
||||
def test_all_tool_names_is_union(self):
|
||||
"""ALL_TOOL_NAMES must equal PLATFORM_TOOL_NAMES | SDK_BUILTIN_TOOL_NAMES."""
|
||||
assert ALL_TOOL_NAMES == PLATFORM_TOOL_NAMES | SDK_BUILTIN_TOOL_NAMES
|
||||
|
||||
def test_no_overlap_between_platform_and_sdk(self):
|
||||
"""Platform and SDK built-in names must not overlap."""
|
||||
assert PLATFORM_TOOL_NAMES.isdisjoint(SDK_BUILTIN_TOOL_NAMES)
|
||||
|
||||
def test_known_tools_includes_registry_and_builtins(self):
|
||||
known = all_known_tool_names()
|
||||
assert "run_block" in known
|
||||
assert "Read" in known
|
||||
assert "Task" in known
|
||||
@@ -63,6 +63,50 @@ Example — committing an image file to GitHub:
|
||||
}}
|
||||
```
|
||||
|
||||
### Writing large files — CRITICAL
|
||||
**Never write an entire large document in a single tool call.** When the
|
||||
content you want to write exceeds ~2000 words the tool call's output token
|
||||
limit will silently truncate the arguments, producing an empty `{{}}` input
|
||||
that fails repeatedly.
|
||||
|
||||
**Preferred: compose from file references.** If the data is already in
|
||||
files (tool outputs, workspace files), compose the report in one call
|
||||
using `@@agptfile:` references — the system expands them inline:
|
||||
|
||||
```bash
|
||||
cat > report.md << 'EOF'
|
||||
# Research Report
|
||||
## Data from web research
|
||||
@@agptfile:/home/user/web_results.txt
|
||||
## Block execution output
|
||||
@@agptfile:workspace://<file_id>
|
||||
## Conclusion
|
||||
<brief synthesis>
|
||||
EOF
|
||||
```
|
||||
|
||||
**Fallback: write section-by-section.** When you must generate content
|
||||
from conversation context (no files to reference), split into multiple
|
||||
`bash_exec` calls — one section per call:
|
||||
|
||||
```bash
|
||||
cat > report.md << 'EOF'
|
||||
# Section 1
|
||||
<content from your earlier tool call results>
|
||||
EOF
|
||||
```
|
||||
```bash
|
||||
cat >> report.md << 'EOF'
|
||||
# Section 2
|
||||
<content from your earlier tool call results>
|
||||
EOF
|
||||
```
|
||||
Use `cat >` for the first chunk and `cat >>` to append subsequent chunks.
|
||||
Do not re-fetch or re-generate data you already have from prior tool calls.
|
||||
|
||||
After building the file, reference it with `@@agptfile:` in other tools:
|
||||
`@@agptfile:/home/user/report.md`
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
@@ -161,9 +205,10 @@ Important files (code, configs, outputs) should be saved to workspace to ensure
|
||||
### SDK tool-result files
|
||||
When tool outputs are large, the SDK truncates them and saves the full output to
|
||||
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
|
||||
always use `read_file` or `Read` (NOT `read_workspace_file`).
|
||||
`read_workspace_file` reads from cloud workspace storage, where SDK
|
||||
tool-results are NOT stored.
|
||||
always use `Read` (NOT `bash_exec`, NOT `read_workspace_file`).
|
||||
These files are on the host filesystem — `bash_exec` runs in the sandbox and
|
||||
CANNOT access them. `read_workspace_file` reads from cloud workspace storage,
|
||||
where SDK tool-results are NOT stored.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
|
||||
|
||||
|
||||
@@ -36,6 +36,10 @@ class CoPilotUsageStatus(BaseModel):
|
||||
|
||||
daily: UsageWindow
|
||||
weekly: UsageWindow
|
||||
reset_cost: int = Field(
|
||||
default=0,
|
||||
description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
|
||||
)
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
@@ -61,6 +65,7 @@ async def get_usage_status(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
rate_limit_reset_cost: int = 0,
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get current usage status for a user.
|
||||
|
||||
@@ -68,6 +73,7 @@ async def get_usage_status(
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: Max tokens per day (0 = unlimited).
|
||||
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||
rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
|
||||
|
||||
Returns:
|
||||
CoPilotUsageStatus with current usage and limits.
|
||||
@@ -97,6 +103,7 @@ async def get_usage_status(
|
||||
limit=weekly_token_limit,
|
||||
resets_at=_weekly_reset_time(now=now),
|
||||
),
|
||||
reset_cost=rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
|
||||
@@ -141,6 +148,110 @@ async def check_rate_limit(
|
||||
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
|
||||
|
||||
|
||||
async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
|
||||
"""Reset a user's daily token usage counter in Redis.
|
||||
|
||||
Called after a user pays credits to extend their daily limit.
|
||||
Also reduces the weekly usage counter by ``daily_token_limit`` tokens
|
||||
(clamped to 0) so the user effectively gets one extra day's worth of
|
||||
weekly capacity.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: The configured daily token limit. When positive,
|
||||
the weekly counter is reduced by this amount.
|
||||
|
||||
Fails open: returns False if Redis is unavailable (consistent with
|
||||
the fail-open design of this module).
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Use a MULTI/EXEC transaction so that DELETE (daily) and DECRBY
|
||||
# (weekly) either both execute or neither does. This prevents the
|
||||
# scenario where the daily counter is cleared but the weekly
|
||||
# counter is not decremented — which would let the caller refund
|
||||
# credits even though the daily limit was already reset.
|
||||
d_key = _daily_key(user_id, now=now)
|
||||
w_key = _weekly_key(user_id, now=now) if daily_token_limit > 0 else None
|
||||
|
||||
pipe = redis.pipeline(transaction=True)
|
||||
pipe.delete(d_key)
|
||||
if w_key is not None:
|
||||
pipe.decrby(w_key, daily_token_limit)
|
||||
results = await pipe.execute()
|
||||
|
||||
# Clamp negative weekly counter to 0 (best-effort; not critical).
|
||||
if w_key is not None:
|
||||
new_val = results[1] # DECRBY result
|
||||
if new_val < 0:
|
||||
await redis.set(w_key, 0, keepttl=True)
|
||||
|
||||
logger.info("Reset daily usage for user %s", user_id[:8])
|
||||
return True
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for resetting daily usage")
|
||||
return False
|
||||
|
||||
|
||||
_RESET_LOCK_PREFIX = "copilot:reset_lock"
|
||||
_RESET_COUNT_PREFIX = "copilot:reset_count"
|
||||
|
||||
|
||||
async def acquire_reset_lock(user_id: str, ttl_seconds: int = 10) -> bool:
|
||||
"""Acquire a short-lived lock to serialize rate limit resets per user."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
key = f"{_RESET_LOCK_PREFIX}:{user_id}"
|
||||
return bool(await redis.set(key, "1", nx=True, ex=ttl_seconds))
|
||||
except (RedisError, ConnectionError, OSError) as exc:
|
||||
logger.warning("Redis unavailable for reset lock, rejecting reset: %s", exc)
|
||||
return False
|
||||
|
||||
|
||||
async def release_reset_lock(user_id: str) -> None:
|
||||
"""Release the per-user reset lock."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(f"{_RESET_LOCK_PREFIX}:{user_id}")
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
pass # Lock will expire via TTL
|
||||
|
||||
|
||||
async def get_daily_reset_count(user_id: str) -> int | None:
|
||||
"""Get how many times the user has reset today.
|
||||
|
||||
Returns None when Redis is unavailable so callers can fail-closed
|
||||
for billed operations (as opposed to failing open for read-only
|
||||
rate-limit checks).
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
key = f"{_RESET_COUNT_PREFIX}:{user_id}:{now.strftime('%Y-%m-%d')}"
|
||||
val = await redis.get(key)
|
||||
return int(val or 0)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for reading daily reset count")
|
||||
return None
|
||||
|
||||
|
||||
async def increment_daily_reset_count(user_id: str) -> None:
|
||||
"""Increment and track how many resets this user has done today."""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
key = f"{_RESET_COUNT_PREFIX}:{user_id}:{now.strftime('%Y-%m-%d')}"
|
||||
pipe = redis.pipeline(transaction=True)
|
||||
pipe.incr(key)
|
||||
seconds_until_reset = int((_daily_reset_time(now=now) - now).total_seconds())
|
||||
pipe.expire(key, max(seconds_until_reset, 1))
|
||||
await pipe.execute()
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for tracking reset count")
|
||||
|
||||
|
||||
async def record_token_usage(
|
||||
user_id: str,
|
||||
prompt_tokens: int,
|
||||
@@ -231,6 +342,67 @@ async def record_token_usage(
|
||||
)
|
||||
|
||||
|
||||
async def get_global_rate_limits(
|
||||
user_id: str,
|
||||
config_daily: int,
|
||||
config_weekly: int,
|
||||
) -> tuple[int, int]:
|
||||
"""Resolve global rate limits from LaunchDarkly, falling back to config.
|
||||
|
||||
Args:
|
||||
user_id: User ID for LD flag evaluation context.
|
||||
config_daily: Fallback daily limit from ChatConfig.
|
||||
config_weekly: Fallback weekly limit from ChatConfig.
|
||||
|
||||
Returns:
|
||||
(daily_token_limit, weekly_token_limit) tuple.
|
||||
"""
|
||||
# Lazy import to avoid circular dependency:
|
||||
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
|
||||
from backend.util.feature_flag import Flag, get_feature_flag_value
|
||||
|
||||
daily_raw = await get_feature_flag_value(
|
||||
Flag.COPILOT_DAILY_TOKEN_LIMIT.value, user_id, config_daily
|
||||
)
|
||||
weekly_raw = await get_feature_flag_value(
|
||||
Flag.COPILOT_WEEKLY_TOKEN_LIMIT.value, user_id, config_weekly
|
||||
)
|
||||
try:
|
||||
daily = max(0, int(daily_raw))
|
||||
except (TypeError, ValueError):
|
||||
logger.warning("Invalid LD value for daily token limit: %r", daily_raw)
|
||||
daily = config_daily
|
||||
try:
|
||||
weekly = max(0, int(weekly_raw))
|
||||
except (TypeError, ValueError):
|
||||
logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
|
||||
weekly = config_weekly
|
||||
return daily, weekly
|
||||
|
||||
|
||||
async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:
|
||||
"""Reset a user's usage counters.
|
||||
|
||||
Always deletes the daily Redis key. When *reset_weekly* is ``True``,
|
||||
the weekly key is deleted as well.
|
||||
|
||||
Unlike read paths (``get_usage_status``, ``check_rate_limit``) which
|
||||
fail-open on Redis errors, resets intentionally re-raise so the caller
|
||||
knows the operation did not succeed. A silent failure here would leave
|
||||
the admin believing the counters were zeroed when they were not.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
keys_to_delete = [_daily_key(user_id, now=now)]
|
||||
if reset_weekly:
|
||||
keys_to_delete.append(_weekly_key(user_id, now=now))
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(*keys_to_delete)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for resetting user usage")
|
||||
raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -12,6 +12,7 @@ from .rate_limit import (
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
record_token_usage,
|
||||
reset_daily_usage,
|
||||
)
|
||||
|
||||
_USER = "test-user-rl"
|
||||
@@ -332,3 +333,91 @@ class TestRecordTokenUsage:
|
||||
):
|
||||
# Should not raise — fail-open
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset_daily_usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetDailyUsage:
|
||||
@staticmethod
|
||||
def _make_pipeline_mock(decrby_result: int = 0) -> MagicMock:
|
||||
"""Create a pipeline mock that returns [delete_result, decrby_result]."""
|
||||
pipe = MagicMock()
|
||||
pipe.execute = AsyncMock(return_value=[1, decrby_result])
|
||||
return pipe
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deletes_daily_key(self):
|
||||
mock_pipe = self._make_pipeline_mock(decrby_result=0)
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
result = await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
|
||||
assert result is True
|
||||
mock_pipe.delete.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reduces_weekly_usage_via_decrby(self):
|
||||
"""Weekly counter should be reduced via DECRBY in the pipeline."""
|
||||
mock_pipe = self._make_pipeline_mock(decrby_result=35000)
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
|
||||
mock_pipe.decrby.assert_called_once()
|
||||
mock_redis.set.assert_not_called() # 35000 > 0, no clamp needed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clamps_negative_weekly_to_zero(self):
|
||||
"""If DECRBY goes negative, SET to 0 (outside the pipeline)."""
|
||||
mock_pipe = self._make_pipeline_mock(decrby_result=-5000)
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
|
||||
mock_pipe.decrby.assert_called_once()
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_weekly_reduction_when_daily_limit_zero(self):
|
||||
"""When daily_token_limit is 0, weekly counter should not be touched."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_pipe.execute = AsyncMock(return_value=[1]) # only delete result
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await reset_daily_usage(_USER, daily_token_limit=0)
|
||||
|
||||
mock_pipe.delete.assert_called_once()
|
||||
mock_pipe.decrby.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_redis_unavailable(self):
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
result = await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
|
||||
assert result is False
|
||||
|
||||
294
autogpt_platform/backend/backend/copilot/reset_usage_test.py
Normal file
294
autogpt_platform/backend/backend/copilot/reset_usage_test.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Unit tests for the POST /usage/reset endpoint."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from backend.api.features.chat.routes import reset_copilot_usage
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
|
||||
# Minimal config mock matching ChatConfig fields used by the endpoint.
|
||||
def _make_config(
|
||||
rate_limit_reset_cost: int = 500,
|
||||
daily_token_limit: int = 2_500_000,
|
||||
weekly_token_limit: int = 12_500_000,
|
||||
max_daily_resets: int = 5,
|
||||
):
|
||||
cfg = MagicMock()
|
||||
cfg.rate_limit_reset_cost = rate_limit_reset_cost
|
||||
cfg.daily_token_limit = daily_token_limit
|
||||
cfg.weekly_token_limit = weekly_token_limit
|
||||
cfg.max_daily_resets = max_daily_resets
|
||||
return cfg
|
||||
|
||||
|
||||
def _usage(daily_used: int = 3_000_000, daily_limit: int = 2_500_000):
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used,
|
||||
limit=daily_limit,
|
||||
resets_at=datetime.now(UTC) + timedelta(hours=6),
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=5_000_000,
|
||||
limit=12_500_000,
|
||||
resets_at=datetime.now(UTC) + timedelta(days=3),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_MODULE = "backend.api.features.chat.routes"
|
||||
|
||||
|
||||
def _mock_settings(enable_credit: bool = True):
|
||||
"""Return a mock Settings object with the given enable_credit flag."""
|
||||
mock = MagicMock()
|
||||
mock.config.enable_credit = enable_credit
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestResetCopilotUsage:
|
||||
async def test_feature_disabled_returns_400(self):
|
||||
"""When rate_limit_reset_cost=0, endpoint returns 400."""
|
||||
|
||||
with patch(f"{_MODULE}.config", _make_config(rate_limit_reset_cost=0)):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "not available" in exc_info.value.detail
|
||||
|
||||
async def test_no_daily_limit_returns_400(self):
|
||||
"""When daily_token_limit=0 (unlimited), endpoint returns 400."""
|
||||
|
||||
with (
|
||||
patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "nothing to reset" in exc_info.value.detail.lower()
|
||||
|
||||
async def test_not_at_limit_returns_400(self):
|
||||
"""When user hasn't hit their daily limit, returns 400."""
|
||||
|
||||
cfg = _make_config()
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(return_value=_usage(daily_used=1_000_000)),
|
||||
),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "not reached" in exc_info.value.detail
|
||||
mock_release.assert_awaited_once()
|
||||
|
||||
async def test_insufficient_credits_returns_402(self):
|
||||
"""When user doesn't have enough credits, returns 402."""
|
||||
|
||||
mock_credit_model = AsyncMock()
|
||||
mock_credit_model.spend_credits.side_effect = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id="user-1",
|
||||
balance=50,
|
||||
amount=200,
|
||||
)
|
||||
|
||||
cfg = _make_config()
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(return_value=_usage()),
|
||||
),
|
||||
patch(
|
||||
f"{_MODULE}.get_user_credit_model",
|
||||
AsyncMock(return_value=mock_credit_model),
|
||||
),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 402
|
||||
mock_release.assert_awaited_once()
|
||||
|
||||
async def test_happy_path(self):
|
||||
"""Successful reset: charges credits, resets usage, returns response."""
|
||||
|
||||
mock_credit_model = AsyncMock()
|
||||
mock_credit_model.spend_credits.return_value = 1500 # remaining balance
|
||||
|
||||
cfg = _make_config()
|
||||
updated_usage = _usage(daily_used=0)
|
||||
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(side_effect=[_usage(), updated_usage]),
|
||||
),
|
||||
patch(
|
||||
f"{_MODULE}.get_user_credit_model",
|
||||
AsyncMock(return_value=mock_credit_model),
|
||||
),
|
||||
patch(
|
||||
f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=True)
|
||||
) as mock_reset,
|
||||
patch(f"{_MODULE}.increment_daily_reset_count", AsyncMock()) as mock_incr,
|
||||
):
|
||||
result = await reset_copilot_usage(user_id="user-1")
|
||||
assert result.success is True
|
||||
assert result.credits_charged == 500
|
||||
assert result.remaining_balance == 1500
|
||||
mock_reset.assert_awaited_once()
|
||||
mock_incr.assert_awaited_once()
|
||||
|
||||
async def test_max_daily_resets_exceeded(self):
|
||||
"""When user has exhausted daily resets, returns 429."""
|
||||
|
||||
cfg = _make_config(max_daily_resets=3)
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
async def test_credit_system_disabled_returns_400(self):
|
||||
"""When enable_credit=False, endpoint returns 400."""
|
||||
|
||||
with (
|
||||
patch(f"{_MODULE}.config", _make_config()),
|
||||
patch(f"{_MODULE}.settings", _mock_settings(enable_credit=False)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "credit system is disabled" in exc_info.value.detail.lower()
|
||||
|
||||
async def test_weekly_limit_exhausted_returns_400(self):
|
||||
"""When the weekly limit is also exhausted, resetting daily won't help."""
|
||||
|
||||
cfg = _make_config()
|
||||
weekly_exhausted = CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=3_000_000,
|
||||
limit=2_500_000,
|
||||
resets_at=datetime.now(UTC) + timedelta(hours=6),
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=12_500_000,
|
||||
limit=12_500_000,
|
||||
resets_at=datetime.now(UTC) + timedelta(days=3),
|
||||
),
|
||||
)
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(return_value=weekly_exhausted),
|
||||
),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "weekly" in exc_info.value.detail.lower()
|
||||
mock_release.assert_awaited_once()
|
||||
|
||||
async def test_redis_failure_for_reset_count_returns_503(self):
|
||||
"""When Redis is unavailable for get_daily_reset_count, returns 503."""
|
||||
|
||||
with (
|
||||
patch(f"{_MODULE}.config", _make_config()),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 503
|
||||
assert "verify" in exc_info.value.detail.lower()
|
||||
|
||||
async def test_redis_reset_failure_refunds_credits(self):
|
||||
"""When reset_daily_usage fails, credits are refunded and 503 returned."""
|
||||
|
||||
mock_credit_model = AsyncMock()
|
||||
mock_credit_model.spend_credits.return_value = 1500
|
||||
|
||||
cfg = _make_config()
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(return_value=_usage()),
|
||||
),
|
||||
patch(
|
||||
f"{_MODULE}.get_user_credit_model",
|
||||
AsyncMock(return_value=mock_credit_model),
|
||||
),
|
||||
patch(f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=False)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 503
|
||||
assert "not been charged" in exc_info.value.detail
|
||||
mock_credit_model.top_up_credits.assert_awaited_once()
|
||||
|
||||
async def test_redis_reset_failure_refund_also_fails(self):
|
||||
"""When both reset and refund fail, error message reflects the truth."""
|
||||
|
||||
mock_credit_model = AsyncMock()
|
||||
mock_credit_model.spend_credits.return_value = 1500
|
||||
mock_credit_model.top_up_credits.side_effect = RuntimeError("db down")
|
||||
|
||||
cfg = _make_config()
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
|
||||
patch(
|
||||
f"{_MODULE}.get_usage_status",
|
||||
AsyncMock(return_value=_usage()),
|
||||
),
|
||||
patch(
|
||||
f"{_MODULE}.get_user_credit_model",
|
||||
AsyncMock(return_value=mock_credit_model),
|
||||
),
|
||||
patch(f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=False)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
assert exc_info.value.status_code == 503
|
||||
assert "contact support" in exc_info.value.detail.lower()
|
||||
@@ -67,9 +67,17 @@ These define the agent's interface — what it accepts and what it produces.
|
||||
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
|
||||
- Defines a user-facing input field on the agent
|
||||
- Required `input_default` fields: `name` (str), `value` (default: null)
|
||||
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
|
||||
- Optional: `title`, `description`
|
||||
- Output: `result` — the user-provided value at runtime
|
||||
- Create one AgentInputBlock per distinct input the agent needs
|
||||
- For dropdown/select inputs, use **AgentDropdownInputBlock** instead (see below)
|
||||
|
||||
**AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
|
||||
- Specialized input block that presents a dropdown/select to the user
|
||||
- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
|
||||
- Optional: `title`, `description`, `value` (default selection)
|
||||
- Output: `result` — the user-selected value at runtime
|
||||
- Use this instead of AgentInputBlock when the user should pick from a fixed set of options
|
||||
|
||||
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
|
||||
- Defines a user-facing output displayed after the agent runs
|
||||
@@ -143,11 +151,11 @@ To use an MCP (Model Context Protocol) tool as a node in the agent:
|
||||
tool_arguments.
|
||||
6. Output: `result` (the tool's return value) and `error` (error message)
|
||||
|
||||
### Using SmartDecisionMakerBlock (AI Orchestrator with Agent Mode)
|
||||
### Using OrchestratorBlock (AI Orchestrator with Agent Mode)
|
||||
|
||||
To create an agent where AI autonomously decides which tools or sub-agents to
|
||||
call in a loop until the task is complete:
|
||||
1. Create a `SmartDecisionMakerBlock` node
|
||||
1. Create a `OrchestratorBlock` node
|
||||
(ID: `3b191d9f-356f-482d-8238-ba04b6d18381`)
|
||||
2. Set `input_default`:
|
||||
- `agent_mode_max_iterations`: Choose based on task complexity:
|
||||
@@ -169,8 +177,8 @@ call in a loop until the task is complete:
|
||||
3. Wire the `prompt` input from an `AgentInputBlock` (the user's task)
|
||||
4. Create downstream tool blocks — regular blocks **or** `AgentExecutorBlock`
|
||||
nodes that call sub-agents
|
||||
5. Link each tool to the SmartDecisionMaker: set `source_name: "tools"` on
|
||||
the SmartDecisionMaker side and `sink_name: <input_field>` on each tool
|
||||
5. Link each tool to the Orchestrator: set `source_name: "tools"` on
|
||||
the Orchestrator side and `sink_name: <input_field>` on each tool
|
||||
block's input. Create one link per input field the tool needs.
|
||||
6. Wire the `finished` output to an `AgentOutputBlock` for the final result
|
||||
7. Credentials (LLM API key) are configured by the user in the platform UI
|
||||
@@ -178,35 +186,49 @@ call in a loop until the task is complete:
|
||||
|
||||
**Example — Orchestrator calling two sub-agents:**
|
||||
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
|
||||
- Node 2: `SmartDecisionMakerBlock` (input_default:
|
||||
- Node 2: `OrchestratorBlock` (input_default:
|
||||
`{"agent_mode_max_iterations": 10, "conversation_compaction": true}`)
|
||||
- Node 3: `AgentExecutorBlock` (sub-agent A — set `graph_id`, `graph_version`,
|
||||
`input_schema`, `output_schema` from library agent)
|
||||
- Node 4: `AgentExecutorBlock` (sub-agent B — same pattern)
|
||||
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
|
||||
- Links:
|
||||
- Input→SDM: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- SDM→Agent A (per input field): `source_name: "tools"`,
|
||||
- Input→Orchestrator: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- Orchestrator→Agent A (per input field): `source_name: "tools"`,
|
||||
`sink_name: "<agent_a_input_field>"`
|
||||
- SDM→Agent B (per input field): `source_name: "tools"`,
|
||||
- Orchestrator→Agent B (per input field): `source_name: "tools"`,
|
||||
`sink_name: "<agent_b_input_field>"`
|
||||
- SDM→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
- Orchestrator→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
|
||||
**Example — Orchestrator calling regular blocks as tools:**
|
||||
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
|
||||
- Node 2: `SmartDecisionMakerBlock` (input_default:
|
||||
- Node 2: `OrchestratorBlock` (input_default:
|
||||
`{"agent_mode_max_iterations": 5, "conversation_compaction": true}`)
|
||||
- Node 3: `GetWebpageBlock` (regular block — the AI calls it as a tool)
|
||||
- Node 4: `AITextGeneratorBlock` (another regular block as a tool)
|
||||
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
|
||||
- Links:
|
||||
- Input→SDM: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- SDM→GetWebpage: `source_name: "tools"`, `sink_name: "url"`
|
||||
- SDM→AITextGenerator: `source_name: "tools"`, `sink_name: "prompt"`
|
||||
- SDM→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
- Input→Orchestrator: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- Orchestrator→GetWebpage: `source_name: "tools"`, `sink_name: "url"`
|
||||
- Orchestrator→AITextGenerator: `source_name: "tools"`, `sink_name: "prompt"`
|
||||
- Orchestrator→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
|
||||
Regular blocks work exactly like sub-agents as tools — wire each input
|
||||
field from `source_name: "tools"` on the SmartDecisionMaker side.
|
||||
field from `source_name: "tools"` on the Orchestrator side.
|
||||
|
||||
### Testing with Dry Run
|
||||
|
||||
After saving an agent, suggest a dry run to validate wiring without consuming
|
||||
real API calls, credentials, or credits:
|
||||
|
||||
1. **Run**: Call `run_agent` or `run_block` with `dry_run=True` and provide
|
||||
sample inputs. This executes the graph with mock outputs, verifying that
|
||||
links resolve correctly and required inputs are satisfied.
|
||||
2. **Check results**: Call `view_agent_output` with `show_execution_details=True`
|
||||
to inspect the full node-by-node execution trace. This shows what each node
|
||||
received as input and produced as output, making it easy to spot wiring issues.
|
||||
3. **Iterate**: If the dry run reveals wiring issues or missing inputs, fix
|
||||
the agent JSON and re-save before suggesting a real execution.
|
||||
|
||||
### Example: Simple AI Text Processor
|
||||
|
||||
|
||||
@@ -7,7 +7,35 @@ without implementing their own event loop.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from .. import stream_registry
|
||||
from ..response_model import (
|
||||
StreamError,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from .service import stream_chat_completion_sdk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Identifiers used when registering AutoPilot-originated streams in the
|
||||
# stream registry. Distinct from "chat_stream"/"chat" used by the HTTP SSE
|
||||
# endpoint, making it easy to filter AutoPilot streams in logs/observability.
|
||||
AUTOPILOT_TOOL_CALL_ID = "autopilot_stream"
|
||||
AUTOPILOT_TOOL_NAME = "autopilot"
|
||||
|
||||
|
||||
class CopilotResult:
|
||||
@@ -33,26 +61,131 @@ class CopilotResult:
|
||||
self.total_tokens: int = 0
|
||||
|
||||
|
||||
class _RegistryHandle(BaseModel):
|
||||
"""Tracks stream registry session state for cleanup."""
|
||||
|
||||
publish_turn_id: str = ""
|
||||
error_msg: str | None = None
|
||||
error_already_published: bool = False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _registry_session(
|
||||
session_id: str, user_id: str, turn_id: str
|
||||
) -> AsyncIterator[_RegistryHandle]:
|
||||
"""Create a stream registry session and ensure it is finalized."""
|
||||
handle = _RegistryHandle(publish_turn_id=turn_id)
|
||||
try:
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id=AUTOPILOT_TOOL_CALL_ID,
|
||||
tool_name=AUTOPILOT_TOOL_NAME,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning(
|
||||
"[collect] Failed to create stream registry session for %s, "
|
||||
"frontend will not receive real-time updates",
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
# Disable chunk publishing but keep finalization enabled so
|
||||
# mark_session_completed can clean up any partial registry state.
|
||||
handle.publish_turn_id = ""
|
||||
|
||||
try:
|
||||
yield handle
|
||||
finally:
|
||||
try:
|
||||
await stream_registry.mark_session_completed(
|
||||
session_id,
|
||||
error_message=handle.error_msg,
|
||||
skip_error_publish=handle.error_already_published,
|
||||
)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning(
|
||||
"[collect] Failed to mark stream completed for %s",
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
class _ToolCallEntry(BaseModel):
|
||||
"""A single tool call observed during stream consumption."""
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
input: Any
|
||||
output: Any = None
|
||||
success: bool | None = None
|
||||
|
||||
|
||||
class _EventAccumulator(BaseModel):
|
||||
"""Mutable accumulator for stream events."""
|
||||
|
||||
response_parts: list[str] = Field(default_factory=list)
|
||||
tool_calls: list[_ToolCallEntry] = Field(default_factory=list)
|
||||
tool_calls_by_id: dict[str, _ToolCallEntry] = Field(default_factory=dict)
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
def _process_event(event: object, acc: _EventAccumulator) -> str | None:
|
||||
"""Process a single stream event and return error_msg if StreamError.
|
||||
|
||||
Uses structural pattern matching for dispatch per project guidelines.
|
||||
"""
|
||||
match event:
|
||||
case StreamTextDelta(delta=delta):
|
||||
acc.response_parts.append(delta)
|
||||
case StreamToolInputAvailable() as e:
|
||||
entry = _ToolCallEntry(
|
||||
tool_call_id=e.toolCallId,
|
||||
tool_name=e.toolName,
|
||||
input=e.input,
|
||||
)
|
||||
acc.tool_calls.append(entry)
|
||||
acc.tool_calls_by_id[e.toolCallId] = entry
|
||||
case StreamToolOutputAvailable() as e:
|
||||
if tc := acc.tool_calls_by_id.get(e.toolCallId):
|
||||
tc.output = e.output
|
||||
tc.success = e.success
|
||||
else:
|
||||
logger.debug(
|
||||
"Received tool output for unknown tool_call_id: %s",
|
||||
e.toolCallId,
|
||||
)
|
||||
case StreamUsage() as e:
|
||||
acc.prompt_tokens += e.prompt_tokens
|
||||
acc.completion_tokens += e.completion_tokens
|
||||
acc.total_tokens += e.total_tokens
|
||||
case StreamError(errorText=err):
|
||||
return err
|
||||
return None
|
||||
|
||||
|
||||
async def collect_copilot_response(
|
||||
*,
|
||||
session_id: str,
|
||||
message: str,
|
||||
user_id: str,
|
||||
is_user_message: bool = True,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> CopilotResult:
|
||||
"""Consume :func:`stream_chat_completion_sdk` and return aggregated results.
|
||||
|
||||
This is the recommended entry-point for callers that need a simple
|
||||
request-response interface (e.g. the AutoPilot block) rather than
|
||||
streaming individual events. It avoids duplicating the event-collection
|
||||
logic and does NOT wrap the stream in ``asyncio.timeout`` — the SDK
|
||||
manages its own heartbeat-based timeouts internally.
|
||||
Registers with the stream registry so the frontend can connect via SSE
|
||||
and receive real-time updates while the AutoPilot block is executing.
|
||||
|
||||
Args:
|
||||
session_id: Chat session to use.
|
||||
message: The user message / prompt.
|
||||
user_id: Authenticated user ID.
|
||||
is_user_message: Whether this is a user-initiated message.
|
||||
permissions: Optional capability filter. When provided, restricts
|
||||
which tools and blocks the copilot may use during this execution.
|
||||
|
||||
Returns:
|
||||
A :class:`CopilotResult` with the aggregated response text,
|
||||
@@ -61,48 +194,39 @@ async def collect_copilot_response(
|
||||
Raises:
|
||||
RuntimeError: If the stream yields a ``StreamError`` event.
|
||||
"""
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
turn_id = str(uuid.uuid4())
|
||||
async with _registry_session(session_id, user_id, turn_id) as handle:
|
||||
try:
|
||||
raw_stream = stream_chat_completion_sdk(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
published_stream = stream_registry.stream_and_publish(
|
||||
session_id=session_id,
|
||||
turn_id=handle.publish_turn_id,
|
||||
stream=raw_stream,
|
||||
)
|
||||
|
||||
from .service import stream_chat_completion_sdk
|
||||
acc = _EventAccumulator()
|
||||
async for event in published_stream:
|
||||
if err := _process_event(event, acc):
|
||||
handle.error_msg = err
|
||||
# stream_and_publish skips StreamError events, so
|
||||
# mark_session_completed must publish the error to Redis.
|
||||
handle.error_already_published = False
|
||||
raise RuntimeError(f"Copilot error: {err}")
|
||||
except Exception:
|
||||
if handle.error_msg is None:
|
||||
handle.error_msg = "AutoPilot execution failed"
|
||||
raise
|
||||
|
||||
result = CopilotResult()
|
||||
response_parts: list[str] = []
|
||||
tool_calls_by_id: dict[str, dict[str, Any]] = {}
|
||||
|
||||
async for event in stream_chat_completion_sdk(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
):
|
||||
if isinstance(event, StreamTextDelta):
|
||||
response_parts.append(event.delta)
|
||||
elif isinstance(event, StreamToolInputAvailable):
|
||||
entry: dict[str, Any] = {
|
||||
"tool_call_id": event.toolCallId,
|
||||
"tool_name": event.toolName,
|
||||
"input": event.input,
|
||||
"output": None,
|
||||
"success": None,
|
||||
}
|
||||
result.tool_calls.append(entry)
|
||||
tool_calls_by_id[event.toolCallId] = entry
|
||||
elif isinstance(event, StreamToolOutputAvailable):
|
||||
if tc := tool_calls_by_id.get(event.toolCallId):
|
||||
tc["output"] = event.output
|
||||
tc["success"] = event.success
|
||||
elif isinstance(event, StreamUsage):
|
||||
result.prompt_tokens += event.prompt_tokens
|
||||
result.completion_tokens += event.completion_tokens
|
||||
result.total_tokens += event.total_tokens
|
||||
elif isinstance(event, StreamError):
|
||||
raise RuntimeError(f"Copilot error: {event.errorText}")
|
||||
|
||||
result.response_text = "".join(response_parts)
|
||||
result.response_text = "".join(acc.response_parts)
|
||||
result.tool_calls = [tc.model_dump() for tc in acc.tool_calls]
|
||||
result.prompt_tokens = acc.prompt_tokens
|
||||
result.completion_tokens = acc.completion_tokens
|
||||
result.total_tokens = acc.total_tokens
|
||||
return result
|
||||
|
||||
177
autogpt_platform/backend/backend/copilot/sdk/collect_test.py
Normal file
177
autogpt_platform/backend/backend/copilot/sdk/collect_test.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""Tests for collect_copilot_response stream registry integration."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.copilot.sdk.collect import collect_copilot_response
|
||||
|
||||
|
||||
def _mock_stream_fn(*events):
|
||||
"""Return a callable that returns an async generator."""
|
||||
|
||||
async def _gen(**_kwargs):
|
||||
for e in events:
|
||||
yield e
|
||||
|
||||
return _gen
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry():
|
||||
"""Patch stream_registry module used by collect."""
|
||||
with patch("backend.copilot.sdk.collect.stream_registry") as m:
|
||||
m.create_session = AsyncMock()
|
||||
m.publish_chunk = AsyncMock()
|
||||
m.mark_session_completed = AsyncMock()
|
||||
|
||||
# stream_and_publish: pass-through that also publishes (real logic)
|
||||
# We re-implement the pass-through here so the event loop works,
|
||||
# but still track publish_chunk calls via the mock.
|
||||
async def _stream_and_publish(session_id, turn_id, stream):
|
||||
async for event in stream:
|
||||
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
|
||||
await m.publish_chunk(turn_id, event)
|
||||
yield event
|
||||
|
||||
m.stream_and_publish = _stream_and_publish
|
||||
yield m
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stream_fn_patch():
|
||||
"""Helper to patch stream_chat_completion_sdk."""
|
||||
|
||||
def _patch(events):
|
||||
return patch(
|
||||
"backend.copilot.sdk.collect.stream_chat_completion_sdk",
|
||||
new=_mock_stream_fn(*events),
|
||||
)
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_registry_called_on_success(mock_registry, stream_fn_patch):
|
||||
"""Stream registry create/publish/complete are called correctly on success."""
|
||||
events = [
|
||||
StreamTextDelta(id="t1", delta="Hello "),
|
||||
StreamTextDelta(id="t1", delta="world"),
|
||||
StreamUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
StreamFinish(),
|
||||
]
|
||||
|
||||
with stream_fn_patch(events):
|
||||
result = await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result.response_text == "Hello world"
|
||||
assert result.total_tokens == 15
|
||||
|
||||
mock_registry.create_session.assert_awaited_once()
|
||||
# StreamFinish should NOT be published (mark_session_completed does it)
|
||||
published_types = [
|
||||
type(call.args[1]).__name__
|
||||
for call in mock_registry.publish_chunk.call_args_list
|
||||
]
|
||||
assert "StreamFinish" not in published_types
|
||||
assert "StreamTextDelta" in published_types
|
||||
|
||||
mock_registry.mark_session_completed.assert_awaited_once()
|
||||
_, kwargs = mock_registry.mark_session_completed.call_args
|
||||
assert kwargs.get("error_message") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_registry_error_on_stream_error(mock_registry, stream_fn_patch):
|
||||
"""mark_session_completed receives error message when StreamError occurs."""
|
||||
events = [
|
||||
StreamTextDelta(id="t1", delta="partial"),
|
||||
StreamError(errorText="something broke"),
|
||||
]
|
||||
|
||||
with stream_fn_patch(events):
|
||||
with pytest.raises(RuntimeError, match="something broke"):
|
||||
await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
_, kwargs = mock_registry.mark_session_completed.call_args
|
||||
assert kwargs.get("error_message") == "something broke"
|
||||
# stream_and_publish skips StreamError, so mark_session_completed must
|
||||
# publish it (skip_error_publish=False).
|
||||
assert kwargs.get("skip_error_publish") is False
|
||||
|
||||
# StreamError should NOT be published via publish_chunk — mark_session_completed
|
||||
# handles it to avoid double-publication.
|
||||
published_types = [
|
||||
type(call.args[1]).__name__
|
||||
for call in mock_registry.publish_chunk.call_args_list
|
||||
]
|
||||
assert "StreamError" not in published_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graceful_degradation_when_create_session_fails(
|
||||
mock_registry, stream_fn_patch
|
||||
):
|
||||
"""AutoPilot still works when stream registry create_session raises."""
|
||||
events = [
|
||||
StreamTextDelta(id="t1", delta="works"),
|
||||
StreamFinish(),
|
||||
]
|
||||
mock_registry.create_session = AsyncMock(side_effect=ConnectionError("Redis down"))
|
||||
|
||||
with stream_fn_patch(events):
|
||||
result = await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result.response_text == "works"
|
||||
# publish_chunk should NOT be called because turn_id was cleared
|
||||
mock_registry.publish_chunk.assert_not_awaited()
|
||||
# mark_session_completed IS still called to clean up any partial state
|
||||
mock_registry.mark_session_completed.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_calls_published_and_collected(mock_registry, stream_fn_patch):
|
||||
"""Tool call events are both published to registry and collected in result."""
|
||||
events = [
|
||||
StreamToolInputAvailable(
|
||||
toolCallId="tc-1", toolName="read_file", input={"path": "/tmp"}
|
||||
),
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId="tc-1", output="file contents", success=True
|
||||
),
|
||||
StreamTextDelta(id="t1", delta="done"),
|
||||
StreamFinish(),
|
||||
]
|
||||
|
||||
with stream_fn_patch(events):
|
||||
result = await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0]["tool_name"] == "read_file"
|
||||
assert result.tool_calls[0]["output"] == "file contents"
|
||||
assert result.tool_calls[0]["success"] is True
|
||||
assert result.response_text == "done"
|
||||
@@ -25,24 +25,64 @@ def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
|
||||
|
||||
Use this helper in any copilot SDK test that needs a well-formed
|
||||
transcript without hitting the real storage layer.
|
||||
|
||||
Delegates to ``build_structured_transcript`` — plain content strings
|
||||
are automatically wrapped in ``[{"type": "text", "text": ...}]`` for
|
||||
assistant messages.
|
||||
"""
|
||||
# Cast widening: tuple[str, str] is structurally compatible with
|
||||
# tuple[str, str | list[dict]] but list invariance requires explicit
|
||||
# annotation.
|
||||
widened: list[tuple[str, str | list[dict]]] = list(pairs)
|
||||
return build_structured_transcript(widened)
|
||||
|
||||
|
||||
def build_structured_transcript(
|
||||
entries: list[tuple[str, str | list[dict]]],
|
||||
) -> str:
|
||||
"""Build a JSONL transcript with structured content blocks.
|
||||
|
||||
Each entry is (role, content) where content is either a plain string
|
||||
(for user messages) or a list of content block dicts (for assistant
|
||||
messages with thinking/tool_use/text blocks).
|
||||
|
||||
Example::
|
||||
|
||||
build_structured_transcript([
|
||||
("user", "Hello"),
|
||||
("assistant", [
|
||||
{"type": "thinking", "thinking": "...", "signature": "sig1"},
|
||||
{"type": "text", "text": "Hi there"},
|
||||
]),
|
||||
])
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for role, content in pairs:
|
||||
for role, content in entries:
|
||||
uid = str(uuid4())
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
msg: dict = {"role": role, "content": content}
|
||||
if role == "assistant":
|
||||
msg.update(
|
||||
{
|
||||
"model": "",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
)
|
||||
if role == "assistant" and isinstance(content, list):
|
||||
msg: dict = {
|
||||
"role": "assistant",
|
||||
"model": "claude-test",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": content,
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
elif role == "assistant":
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"model": "claude-test",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
msg = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
|
||||
Glob/Grep so that all file operations share the same ``/home/user``
|
||||
filesystem as ``bash_exec``.
|
||||
and ``/tmp`` filesystems as ``bash_exec``.
|
||||
|
||||
SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
|
||||
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
||||
@@ -16,10 +16,13 @@ import shlex
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.copilot.context import (
|
||||
E2B_ALLOWED_DIRS,
|
||||
E2B_ALLOWED_DIRS_STR,
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
is_within_allowed_dirs,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
@@ -36,7 +39,7 @@ async def _check_sandbox_symlink_escape(
|
||||
``readlink -f`` follows actual symlinks on the sandbox filesystem.
|
||||
|
||||
Returns the canonical parent path, or ``None`` if the path escapes
|
||||
``E2B_WORKDIR``.
|
||||
the allowed sandbox directories.
|
||||
|
||||
Note: There is an inherent TOCTOU window between this check and the
|
||||
subsequent ``sandbox.files.write()``. A symlink could theoretically be
|
||||
@@ -52,10 +55,7 @@ async def _check_sandbox_symlink_escape(
|
||||
if (
|
||||
canonical_res.exit_code != 0
|
||||
or not canonical_parent
|
||||
or (
|
||||
canonical_parent != E2B_WORKDIR
|
||||
and not canonical_parent.startswith(E2B_WORKDIR + "/")
|
||||
)
|
||||
or not is_within_allowed_dirs(canonical_parent)
|
||||
):
|
||||
return None
|
||||
return canonical_parent
|
||||
@@ -89,6 +89,38 @@ def _get_sandbox_and_path(
|
||||
return sandbox, remote
|
||||
|
||||
|
||||
async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
|
||||
"""Write *content* to *path* inside the sandbox.
|
||||
|
||||
The E2B filesystem API (``sandbox.files.write``) and the command API
|
||||
(``sandbox.commands.run``) run as **different users**. On ``/tmp``
|
||||
(which has the sticky bit set) this means ``sandbox.files.write`` can
|
||||
create new files but cannot overwrite files previously created by
|
||||
``sandbox.commands.run`` (or itself), because the sticky bit restricts
|
||||
deletion/rename to the file owner.
|
||||
|
||||
To work around this, writes targeting ``/tmp`` are performed via
|
||||
``tee`` through the command API, which runs as the sandbox ``user``
|
||||
and can therefore always overwrite user-owned files.
|
||||
"""
|
||||
if path == "/tmp" or path.startswith("/tmp/"):
|
||||
import base64 as _b64
|
||||
|
||||
encoded = _b64.b64encode(content.encode()).decode()
|
||||
result = await sandbox.commands.run(
|
||||
f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(path)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=10,
|
||||
)
|
||||
if result.exit_code != 0:
|
||||
raise RuntimeError(
|
||||
f"shell write failed (exit {result.exit_code}): "
|
||||
+ (result.stderr or "").strip()
|
||||
)
|
||||
else:
|
||||
await sandbox.files.write(path, content)
|
||||
|
||||
|
||||
# Tool handlers
|
||||
|
||||
|
||||
@@ -139,13 +171,16 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
try:
|
||||
parent = os.path.dirname(remote)
|
||||
if parent and parent != E2B_WORKDIR:
|
||||
if parent and parent not in E2B_ALLOWED_DIRS:
|
||||
await sandbox.files.make_dir(parent)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
return _mcp(
|
||||
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(parent)}",
|
||||
error=True,
|
||||
)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
await sandbox.files.write(remote, content)
|
||||
await _sandbox_write(sandbox, remote, content)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
@@ -172,7 +207,10 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
parent = os.path.dirname(remote)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
return _mcp(
|
||||
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(parent)}",
|
||||
error=True,
|
||||
)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
|
||||
try:
|
||||
@@ -197,7 +235,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
else content.replace(old_string, new_string, 1)
|
||||
)
|
||||
try:
|
||||
await sandbox.files.write(remote, updated)
|
||||
await _sandbox_write(sandbox, remote, updated)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
@@ -290,14 +328,14 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
|
||||
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
(
|
||||
"read_file",
|
||||
"Read a file from the cloud sandbox (/home/user). "
|
||||
"Read a file from the cloud sandbox (/home/user or /tmp). "
|
||||
"Use offset and limit for large files.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
@@ -314,7 +352,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
),
|
||||
(
|
||||
"write_file",
|
||||
"Write or create a file in the cloud sandbox (/home/user). "
|
||||
"Write or create a file in the cloud sandbox (/home/user or /tmp). "
|
||||
"Parent directories are created automatically. "
|
||||
"To copy a workspace file into the sandbox, use "
|
||||
"read_workspace_file with save_to_path instead.",
|
||||
@@ -323,7 +361,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
|
||||
},
|
||||
"content": {"type": "string", "description": "Content to write."},
|
||||
},
|
||||
@@ -340,7 +378,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
|
||||
},
|
||||
"old_string": {"type": "string", "description": "Text to find."},
|
||||
"new_string": {"type": "string", "description": "Replacement text."},
|
||||
|
||||
@@ -15,6 +15,7 @@ from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_proj
|
||||
from .e2b_file_tools import (
|
||||
_check_sandbox_symlink_escape,
|
||||
_read_local,
|
||||
_sandbox_write,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
@@ -39,23 +40,23 @@ class TestResolveSandboxPath:
|
||||
assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
@@ -68,6 +69,24 @@ class TestResolveSandboxPath:
|
||||
"""Path that resolves back within E2B_WORKDIR is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"
|
||||
|
||||
def test_tmp_absolute_allowed(self):
|
||||
assert resolve_sandbox_path("/tmp/data.txt") == "/tmp/data.txt"
|
||||
|
||||
def test_tmp_nested_allowed(self):
|
||||
assert resolve_sandbox_path("/tmp/a/b/c.txt") == "/tmp/a/b/c.txt"
|
||||
|
||||
def test_tmp_itself_allowed(self):
|
||||
assert resolve_sandbox_path("/tmp") == "/tmp"
|
||||
|
||||
def test_tmp_escape_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/tmp/../etc/passwd")
|
||||
|
||||
def test_tmp_prefix_collision_blocked(self):
|
||||
"""A path like /tmp_evil should be blocked (not a prefix match)."""
|
||||
with pytest.raises(ValueError, match="must be within"):
|
||||
resolve_sandbox_path("/tmp_evil/malicious.txt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_local — host filesystem reads with allowlist enforcement
|
||||
@@ -227,3 +246,92 @@ class TestCheckSandboxSymlinkEscape:
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
|
||||
assert result == f"{E2B_WORKDIR}/a/b/c/d"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_path_allowed(self):
|
||||
"""Paths resolving to /tmp are allowed."""
|
||||
sandbox = _make_sandbox(stdout="/tmp/workdir\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, "/tmp/workdir")
|
||||
assert result == "/tmp/workdir"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_itself_allowed(self):
|
||||
"""The /tmp directory itself is allowed."""
|
||||
sandbox = _make_sandbox(stdout="/tmp\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, "/tmp")
|
||||
assert result == "/tmp"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sandbox_write — routing writes through shell for /tmp paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSandboxWrite:
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_path_uses_shell_command(self):
|
||||
"""Writes to /tmp should use commands.run (shell) instead of files.write."""
|
||||
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
files = SimpleNamespace(write=AsyncMock())
|
||||
sandbox = SimpleNamespace(commands=commands, files=files)
|
||||
|
||||
await _sandbox_write(sandbox, "/tmp/test.py", "print('hello')")
|
||||
|
||||
commands.run.assert_called_once()
|
||||
files.write.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_user_path_uses_files_api(self):
|
||||
"""Writes to /home/user should use sandbox.files.write."""
|
||||
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
files = SimpleNamespace(write=AsyncMock())
|
||||
sandbox = SimpleNamespace(commands=commands, files=files)
|
||||
|
||||
await _sandbox_write(sandbox, "/home/user/test.py", "print('hello')")
|
||||
|
||||
files.write.assert_called_once_with("/home/user/test.py", "print('hello')")
|
||||
commands.run.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_nested_path_uses_shell_command(self):
|
||||
"""Writes to nested /tmp paths should use commands.run."""
|
||||
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
files = SimpleNamespace(write=AsyncMock())
|
||||
sandbox = SimpleNamespace(commands=commands, files=files)
|
||||
|
||||
await _sandbox_write(sandbox, "/tmp/subdir/file.txt", "content")
|
||||
|
||||
commands.run.assert_called_once()
|
||||
files.write.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_write_shell_failure_raises(self):
|
||||
"""Shell write failure should raise RuntimeError."""
|
||||
run_result = SimpleNamespace(stdout="", stderr="No space left", exit_code=1)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
sandbox = SimpleNamespace(commands=commands)
|
||||
|
||||
with pytest.raises(RuntimeError, match="shell write failed"):
|
||||
await _sandbox_write(sandbox, "/tmp/test.txt", "content")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_write_preserves_content_with_special_chars(self):
|
||||
"""Content with special shell characters should be preserved via base64."""
|
||||
import base64
|
||||
|
||||
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
sandbox = SimpleNamespace(commands=commands)
|
||||
|
||||
content = "print(\"Hello $USER\")\n# a `backtick` and 'quotes'\n"
|
||||
await _sandbox_write(sandbox, "/tmp/special.py", content)
|
||||
|
||||
# Verify the command contains base64-encoded content
|
||||
call_args = commands.run.call_args[0][0]
|
||||
# Extract the base64 string from the command
|
||||
encoded_in_cmd = call_args.split("echo ")[1].split(" |")[0].strip("'")
|
||||
decoded = base64.b64decode(encoded_in_cmd).decode()
|
||||
assert decoded == content
|
||||
|
||||
68
autogpt_platform/backend/backend/copilot/sdk/env.py
Normal file
68
autogpt_platform/backend/backend/copilot/sdk/env.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""SDK environment variable builder — importable without circular deps.
|
||||
|
||||
Extracted from ``service.py`` so that ``backend.blocks.orchestrator``
|
||||
can reuse the same subscription / OpenRouter / direct-Anthropic logic
|
||||
without pulling in the full copilot service module (which would create a
|
||||
circular import through ``executor`` → ``credit`` → ``block_cost_config``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.sdk.subscription import validate_subscription
|
||||
|
||||
# ChatConfig is stateless (reads env vars) — a separate instance is fine.
|
||||
# A singleton would require importing service.py which causes the circular dep
|
||||
# this module was created to avoid.
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def build_sdk_env(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Build env vars for the SDK CLI subprocess.
|
||||
|
||||
Three modes (checked in order):
|
||||
1. **Subscription** — clears all keys; CLI uses ``claude login`` auth.
|
||||
2. **Direct Anthropic** — returns ``{}``; subprocess inherits
|
||||
``ANTHROPIC_API_KEY`` from the parent environment.
|
||||
3. **OpenRouter** (default) — overrides base URL and auth token to
|
||||
route through the proxy, with Langfuse trace headers.
|
||||
"""
|
||||
# --- Mode 1: Claude Code subscription auth ---
|
||||
if config.use_claude_code_subscription:
|
||||
validate_subscription()
|
||||
return {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
|
||||
# --- Mode 2: Direct Anthropic (no proxy hop) ---
|
||||
if not config.openrouter_active:
|
||||
return {}
|
||||
|
||||
# --- Mode 3: OpenRouter proxy ---
|
||||
base = (config.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
env: dict[str, str] = {
|
||||
"ANTHROPIC_BASE_URL": base,
|
||||
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
|
||||
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
|
||||
}
|
||||
|
||||
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
|
||||
def _safe(v: str) -> str:
|
||||
return v.replace("\r", "").replace("\n", "").strip()[:128]
|
||||
|
||||
parts = []
|
||||
if session_id:
|
||||
parts.append(f"x-session-id: {_safe(session_id)}")
|
||||
if user_id:
|
||||
parts.append(f"x-user-id: {_safe(user_id)}")
|
||||
if parts:
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
|
||||
|
||||
return env
|
||||
242
autogpt_platform/backend/backend/copilot/sdk/env_test.py
Normal file
242
autogpt_platform/backend/backend/copilot/sdk/env_test.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Tests for build_sdk_env() — the SDK subprocess environment builder."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — build a ChatConfig with explicit field values so tests don't
|
||||
# depend on real environment variables.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_config(**overrides) -> ChatConfig:
|
||||
"""Create a ChatConfig with safe defaults, applying *overrides*."""
|
||||
defaults = {
|
||||
"use_claude_code_subscription": False,
|
||||
"use_openrouter": False,
|
||||
"api_key": None,
|
||||
"base_url": None,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return ChatConfig(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode 1 — Subscription auth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvSubscription:
|
||||
"""When ``use_claude_code_subscription`` is True, keys are blanked."""
|
||||
|
||||
@patch("backend.copilot.sdk.env.validate_subscription")
|
||||
def test_returns_blanked_keys(self, mock_validate):
|
||||
"""Subscription mode clears API_KEY, AUTH_TOKEN, and BASE_URL."""
|
||||
cfg = _make_config(use_claude_code_subscription=True)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result == {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
mock_validate.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"backend.copilot.sdk.env.validate_subscription",
|
||||
side_effect=RuntimeError("CLI not found"),
|
||||
)
|
||||
def test_propagates_validation_error(self, mock_validate):
|
||||
"""If validate_subscription fails, the error bubbles up."""
|
||||
cfg = _make_config(use_claude_code_subscription=True)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
with pytest.raises(RuntimeError, match="CLI not found"):
|
||||
build_sdk_env()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode 2 — Direct Anthropic (no OpenRouter)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvDirectAnthropic:
|
||||
"""When OpenRouter is inactive, return empty dict (inherit parent env)."""
|
||||
|
||||
def test_returns_empty_dict_when_openrouter_inactive(self):
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_returns_empty_dict_when_openrouter_flag_true_but_no_key(self):
|
||||
"""OpenRouter flag is True but no api_key => openrouter_active is False."""
|
||||
cfg = _make_config(use_openrouter=True, base_url="https://openrouter.ai/api/v1")
|
||||
# Force api_key to None after construction (field_validator may pick up env vars)
|
||||
object.__setattr__(cfg, "api_key", None)
|
||||
assert not cfg.openrouter_active
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode 3 — OpenRouter proxy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvOpenRouter:
|
||||
"""When OpenRouter is active, return proxy env vars."""
|
||||
|
||||
def _openrouter_config(self, **overrides):
|
||||
defaults = {
|
||||
"use_openrouter": True,
|
||||
"api_key": "sk-or-test-key",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return _make_config(**defaults)
|
||||
|
||||
def test_basic_openrouter_env(self):
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
|
||||
assert result["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test-key"
|
||||
assert result["ANTHROPIC_API_KEY"] == ""
|
||||
assert "ANTHROPIC_CUSTOM_HEADERS" not in result
|
||||
|
||||
def test_strips_trailing_v1(self):
|
||||
"""The /v1 suffix is stripped from the base URL."""
|
||||
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1")
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
|
||||
|
||||
def test_strips_trailing_v1_and_slash(self):
|
||||
"""Trailing slash before /v1 strip is handled."""
|
||||
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1/")
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
# rstrip("/") first, then remove /v1
|
||||
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
|
||||
|
||||
def test_no_v1_suffix_left_alone(self):
|
||||
"""A base URL without /v1 is used as-is."""
|
||||
cfg = self._openrouter_config(base_url="https://custom-proxy.example.com")
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result["ANTHROPIC_BASE_URL"] == "https://custom-proxy.example.com"
|
||||
|
||||
def test_session_id_header(self):
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(session_id="sess-123")
|
||||
|
||||
assert "ANTHROPIC_CUSTOM_HEADERS" in result
|
||||
assert "x-session-id: sess-123" in result["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
|
||||
def test_user_id_header(self):
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(user_id="user-456")
|
||||
|
||||
assert "x-user-id: user-456" in result["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
|
||||
def test_both_headers(self):
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(session_id="s1", user_id="u2")
|
||||
|
||||
headers = result["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
assert "x-session-id: s1" in headers
|
||||
assert "x-user-id: u2" in headers
|
||||
# They should be newline-separated
|
||||
assert "\n" in headers
|
||||
|
||||
def test_header_sanitisation_strips_newlines(self):
|
||||
"""Newlines/carriage-returns in header values are stripped."""
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(session_id="bad\r\nvalue")
|
||||
|
||||
header_val = result["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
# The _safe helper removes \r and \n
|
||||
assert "\r" not in header_val.split(": ", 1)[1]
|
||||
assert "badvalue" in header_val
|
||||
|
||||
def test_header_value_truncated_to_128_chars(self):
|
||||
"""Header values are truncated to 128 characters."""
|
||||
cfg = self._openrouter_config()
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
long_id = "x" * 200
|
||||
result = build_sdk_env(session_id=long_id)
|
||||
|
||||
# The value after "x-session-id: " should be at most 128 chars
|
||||
header_line = result["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
value = header_line.split(": ", 1)[1]
|
||||
assert len(value) == 128
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode priority
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvModePriority:
|
||||
"""Subscription mode takes precedence over OpenRouter."""
|
||||
|
||||
@patch("backend.copilot.sdk.env.validate_subscription")
|
||||
def test_subscription_overrides_openrouter(self, mock_validate):
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=True,
|
||||
use_openrouter=True,
|
||||
api_key="sk-or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
# Should get subscription result, not OpenRouter
|
||||
assert result == {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
@@ -442,8 +442,11 @@ class TestCompactTranscript:
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
# 3 messages: compressed prefix (2) + preserved last assistant (1)
|
||||
assert len(msgs) == 3
|
||||
assert msgs[1]["content"] == "Summarized response"
|
||||
# The last assistant entry is preserved verbatim from original
|
||||
assert msgs[2]["content"] == "Details"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_compression_failure(self, mock_chat_config):
|
||||
|
||||
@@ -15,6 +15,7 @@ from claude_agent_sdk import (
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
@@ -100,6 +101,11 @@ class SDKResponseAdapter:
|
||||
StreamTextDelta(id=self.text_block_id, delta=block.text)
|
||||
)
|
||||
|
||||
elif isinstance(block, ThinkingBlock):
|
||||
# Thinking blocks are preserved in the transcript but
|
||||
# not streamed to the frontend — skip silently.
|
||||
pass
|
||||
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
self._end_text_if_open(responses)
|
||||
|
||||
|
||||
@@ -124,8 +124,11 @@ class TestScenarioCompactAndRetry:
|
||||
assert result != original # Must be different
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
# 3 messages: compressed prefix (2) + preserved last assistant (1)
|
||||
assert len(msgs) == 3
|
||||
assert msgs[0]["content"] == "[summary of conversation]"
|
||||
# Last assistant preserved verbatim
|
||||
assert msgs[2]["content"] == "Long answer 2"
|
||||
|
||||
def test_compacted_transcript_loads_into_builder(self):
|
||||
"""TranscriptBuilder can load a compacted transcript and continue."""
|
||||
@@ -737,7 +740,10 @@ class TestRetryEdgeCases:
|
||||
assert result is not None
|
||||
assert result != transcript
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
# 3 messages: compressed prefix (2) + preserved last assistant (1)
|
||||
assert len(msgs) == 3
|
||||
# Last assistant preserved verbatim
|
||||
assert msgs[2]["content"] == "Answer 19"
|
||||
|
||||
def test_messages_to_transcript_roundtrip_preserves_content(self):
|
||||
"""Verify messages → transcript → messages preserves all content."""
|
||||
@@ -1004,7 +1010,7 @@ def _make_sdk_patches(
|
||||
(f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
|
||||
(f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
|
||||
(f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
|
||||
(f"{_SVC}._build_sdk_env", dict(return_value=None)),
|
||||
(f"{_SVC}.build_sdk_env", dict(return_value=None)),
|
||||
(f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
|
||||
(f"{_SVC}.set_execution_context", {}),
|
||||
(
|
||||
|
||||
@@ -12,7 +12,10 @@ import time
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, NamedTuple, cast
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
@@ -29,6 +32,7 @@ from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.copilot.permissions import apply_tool_permissions
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -73,13 +77,17 @@ from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tracking import track_user_message
|
||||
from .compaction import CompactionTracker, filter_compaction_messages
|
||||
from .env import build_sdk_env # noqa: F401 — re-export for backward compat
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .security_hooks import create_security_hooks
|
||||
from .subscription import validate_subscription as _validate_claude_code_subscription
|
||||
from .tool_adapter import (
|
||||
cancel_pending_tool_tasks,
|
||||
create_copilot_mcp_server,
|
||||
get_copilot_tool_names,
|
||||
get_sdk_disallowed_tools,
|
||||
pre_launch_tool_call,
|
||||
reset_stash_event,
|
||||
reset_tool_failure_counters,
|
||||
set_execution_context,
|
||||
wait_for_stash,
|
||||
)
|
||||
@@ -105,6 +113,20 @@ config = ChatConfig()
|
||||
# Non-context errors (network, auth, rate-limit) are NOT retried.
|
||||
_MAX_STREAM_ATTEMPTS = 3
|
||||
|
||||
# Hard circuit breaker: abort the stream if the model sends this many
|
||||
# consecutive tool calls with empty parameters (a sign of context
|
||||
# saturation or serialization failure). Empty input ({}) is never
|
||||
# legitimate — even one is suspicious, three is conclusive.
|
||||
_EMPTY_TOOL_CALL_LIMIT = 3
|
||||
|
||||
# User-facing error shown when the empty-tool-call circuit breaker trips.
|
||||
_CIRCUIT_BREAKER_ERROR_MSG = (
|
||||
"AutoPilot was unable to complete the tool call "
|
||||
"— this usually happens when the response is "
|
||||
"too large to fit in a single tool call. "
|
||||
"Try breaking your request into smaller parts."
|
||||
)
|
||||
|
||||
# Patterns that indicate the prompt/request exceeds the model's context limit.
|
||||
# Matched case-insensitively against the full exception chain.
|
||||
_PROMPT_TOO_LONG_PATTERNS: tuple[str, ...] = (
|
||||
@@ -163,6 +185,37 @@ def _is_prompt_too_long(err: BaseException) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_sdk_disconnect_error(exc: BaseException) -> bool:
|
||||
"""Return True if *exc* is an expected SDK cleanup error from client disconnect.
|
||||
|
||||
Two known patterns occur when ``GeneratorExit`` tears down the async
|
||||
generator and the SDK's ``__aexit__`` runs in a different context/task:
|
||||
|
||||
* ``RuntimeError``: cancel scope exited in wrong task (anyio)
|
||||
* ``ValueError``: ContextVar token created in a different Context (OTEL)
|
||||
|
||||
These are suppressed to avoid polluting Sentry with non-actionable noise.
|
||||
"""
|
||||
if isinstance(exc, RuntimeError) and "cancel scope" in str(exc):
|
||||
return True
|
||||
if isinstance(exc, ValueError) and "was created in a different Context" in str(exc):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_tool_only_message(sdk_msg: object) -> bool:
|
||||
"""Return True if *sdk_msg* is an AssistantMessage containing only ToolUseBlocks.
|
||||
|
||||
Such a message represents a parallel tool-call batch (no text output yet).
|
||||
The ``bool(…content)`` guard prevents vacuous-truth evaluation on an empty list.
|
||||
"""
|
||||
return (
|
||||
isinstance(sdk_msg, AssistantMessage)
|
||||
and bool(sdk_msg.content)
|
||||
and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
|
||||
)
|
||||
|
||||
|
||||
class ReducedContext(NamedTuple):
|
||||
builder: TranscriptBuilder
|
||||
use_resume: bool
|
||||
@@ -374,6 +427,63 @@ _HEARTBEAT_INTERVAL = 10.0 # seconds
|
||||
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
|
||||
|
||||
|
||||
async def _safe_close_sdk_client(
|
||||
sdk_client: ClaudeSDKClient,
|
||||
log_prefix: str,
|
||||
) -> None:
|
||||
"""Close a ClaudeSDKClient, suppressing errors from client disconnect.
|
||||
|
||||
When the SSE client disconnects mid-stream, ``GeneratorExit`` propagates
|
||||
through the async generator stack and causes ``ClaudeSDKClient.__aexit__``
|
||||
to run in a different async context or task than where the client was
|
||||
opened. This triggers two known error classes:
|
||||
|
||||
* ``ValueError``: ``<Token var=<ContextVar name='current_context'>>
|
||||
was created in a different Context`` — OpenTelemetry's
|
||||
``context.detach()`` fails because the OTEL context token was
|
||||
created in the original generator coroutine but detach runs in
|
||||
the GC / cleanup coroutine (Sentry: AUTOGPT-SERVER-8BT).
|
||||
|
||||
* ``RuntimeError``: ``Attempted to exit cancel scope in a different
|
||||
task than it was entered in`` — anyio's ``TaskGroup.__aexit__``
|
||||
detects that the cancel scope was entered in one task but is
|
||||
being exited in another (Sentry: AUTOGPT-SERVER-8BW).
|
||||
|
||||
Both are harmless — the TCP connection is already dead and no
|
||||
resources leak. Logging them at ``debug`` level keeps observability
|
||||
without polluting Sentry.
|
||||
"""
|
||||
try:
|
||||
await sdk_client.__aexit__(None, None, None)
|
||||
except (ValueError, RuntimeError) as exc:
|
||||
if _is_sdk_disconnect_error(exc):
|
||||
# Expected during client disconnect — suppress to avoid Sentry noise.
|
||||
logger.debug(
|
||||
"%s SDK client cleanup error suppressed (client disconnect): %s: %s",
|
||||
log_prefix,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
except GeneratorExit:
|
||||
# GeneratorExit can propagate through __aexit__ — suppress it here
|
||||
# since the generator is already being torn down.
|
||||
logger.debug(
|
||||
"%s SDK client cleanup GeneratorExit suppressed (client disconnect)",
|
||||
log_prefix,
|
||||
)
|
||||
except Exception:
|
||||
# Unexpected cleanup error — log at error level so Sentry captures it
|
||||
# (via its logging integration), but don't propagate since we're in
|
||||
# teardown and the caller cannot meaningfully handle this.
|
||||
logger.error(
|
||||
"%s Unexpected SDK client cleanup error",
|
||||
log_prefix,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def _iter_sdk_messages(
|
||||
client: ClaudeSDKClient,
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
@@ -457,60 +567,6 @@ def _resolve_sdk_model() -> str | None:
|
||||
return model
|
||||
|
||||
|
||||
def _build_sdk_env(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Build env vars for the SDK CLI subprocess.
|
||||
|
||||
Three modes (checked in order):
|
||||
1. **Subscription** — clears all keys; CLI uses `claude login` auth.
|
||||
2. **Direct Anthropic** — returns `{}`; subprocess inherits
|
||||
`ANTHROPIC_API_KEY` from the parent environment.
|
||||
3. **OpenRouter** (default) — overrides base URL and auth token to
|
||||
route through the proxy, with Langfuse trace headers.
|
||||
"""
|
||||
# --- Mode 1: Claude Code subscription auth ---
|
||||
if config.use_claude_code_subscription:
|
||||
_validate_claude_code_subscription()
|
||||
return {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
|
||||
# --- Mode 2: Direct Anthropic (no proxy hop) ---
|
||||
# `openrouter_active` checks the flag *and* credential presence.
|
||||
if not config.openrouter_active:
|
||||
return {}
|
||||
|
||||
# --- Mode 3: OpenRouter proxy ---
|
||||
# Strip /v1 suffix — SDK expects the base URL without a version path.
|
||||
base = (config.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
env: dict[str, str] = {
|
||||
"ANTHROPIC_BASE_URL": base,
|
||||
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
|
||||
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
|
||||
}
|
||||
|
||||
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
|
||||
def _safe(v: str) -> str:
|
||||
"""Sanitise a header value: strip newlines/whitespace and cap length."""
|
||||
return v.replace("\r", "").replace("\n", "").strip()[:128]
|
||||
|
||||
parts = []
|
||||
if session_id:
|
||||
parts.append(f"x-session-id: {_safe(session_id)}")
|
||||
if user_id:
|
||||
parts.append(f"x-user-id: {_safe(user_id)}")
|
||||
if parts:
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def _make_sdk_cwd(session_id: str) -> str:
|
||||
"""Create a safe, session-specific working directory path.
|
||||
|
||||
@@ -560,7 +616,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
"""Convert SDK content blocks to transcript format.
|
||||
|
||||
Handles TextBlock, ToolUseBlock, ToolResultBlock, and ThinkingBlock.
|
||||
Unknown block types are logged and skipped.
|
||||
Raw dicts (e.g. ``redacted_thinking`` blocks that the SDK may not have
|
||||
a typed class for) are passed through verbatim to preserve them in the
|
||||
transcript. Unknown typed block objects are logged and skipped.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for block in blocks or []:
|
||||
@@ -592,6 +650,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
"signature": block.signature,
|
||||
}
|
||||
)
|
||||
elif isinstance(block, dict) and "type" in block:
|
||||
# Preserve raw dict blocks (e.g. redacted_thinking) verbatim.
|
||||
result.append(block)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Unknown content block type: {type(block).__name__}. "
|
||||
@@ -996,15 +1057,122 @@ def _dispatch_response(
|
||||
return response
|
||||
|
||||
|
||||
class _TransientErrorHandled(Exception):
|
||||
class _HandledStreamError(Exception):
|
||||
"""Raised by `_run_stream_attempt` after it has already yielded a
|
||||
`StreamError` for a transient API error.
|
||||
`StreamError` to the client (e.g. transient API error, circuit breaker).
|
||||
|
||||
This signals the outer retry loop that the attempt failed so it can
|
||||
perform session-message rollback and set the `ended_with_stream_error`
|
||||
flag, **without** yielding a duplicate `StreamError` to the client.
|
||||
|
||||
Attributes:
|
||||
error_msg: The user-facing error message to persist.
|
||||
code: Machine-readable error code (e.g. ``circuit_breaker_empty_tool_calls``).
|
||||
retryable: Whether the frontend should offer a retry button.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
error_msg: str | None = None,
|
||||
code: str | None = None,
|
||||
retryable: bool = True,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.error_msg = error_msg
|
||||
self.code = code
|
||||
self.retryable = retryable
|
||||
|
||||
|
||||
@dataclass
|
||||
class _EmptyToolBreakResult:
|
||||
"""Result of checking for empty tool calls in a single AssistantMessage."""
|
||||
|
||||
count: int # Updated consecutive counter
|
||||
tripped: bool # Whether the circuit breaker fired
|
||||
error: StreamError | None # StreamError to yield (if tripped)
|
||||
error_msg: str | None # Error message (if tripped)
|
||||
error_code: str | None # Error code (if tripped)
|
||||
|
||||
|
||||
def _check_empty_tool_breaker(
|
||||
sdk_msg: object,
|
||||
consecutive: int,
|
||||
ctx: _StreamContext,
|
||||
state: _RetryState,
|
||||
) -> _EmptyToolBreakResult:
|
||||
"""Detect consecutive empty tool calls and trip the circuit breaker.
|
||||
|
||||
Returns an ``_EmptyToolBreakResult`` with the updated counter and, if the
|
||||
breaker tripped, the ``StreamError`` to yield plus the error metadata.
|
||||
"""
|
||||
if not isinstance(sdk_msg, AssistantMessage):
|
||||
return _EmptyToolBreakResult(consecutive, False, None, None, None)
|
||||
|
||||
empty_tools = [
|
||||
b.name for b in sdk_msg.content if isinstance(b, ToolUseBlock) and not b.input
|
||||
]
|
||||
if not empty_tools:
|
||||
# Reset on any non-empty-tool AssistantMessage (including text-only
|
||||
# messages — any() over empty content is False).
|
||||
return _EmptyToolBreakResult(0, False, None, None, None)
|
||||
|
||||
consecutive += 1
|
||||
|
||||
# Log full diagnostics on first occurrence only; subsequent hits just
|
||||
# log the counter to reduce noise.
|
||||
if consecutive == 1:
|
||||
logger.warning(
|
||||
"%s Empty tool call detected (%d/%d): "
|
||||
"tools=%s, model=%s, error=%s, "
|
||||
"block_types=%s, cumulative_usage=%s",
|
||||
ctx.log_prefix,
|
||||
consecutive,
|
||||
_EMPTY_TOOL_CALL_LIMIT,
|
||||
empty_tools,
|
||||
sdk_msg.model,
|
||||
sdk_msg.error,
|
||||
[type(b).__name__ for b in sdk_msg.content],
|
||||
{
|
||||
"prompt": state.usage.prompt_tokens,
|
||||
"completion": state.usage.completion_tokens,
|
||||
"cache_read": state.usage.cache_read_tokens,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"%s Empty tool call detected (%d/%d): tools=%s",
|
||||
ctx.log_prefix,
|
||||
consecutive,
|
||||
_EMPTY_TOOL_CALL_LIMIT,
|
||||
empty_tools,
|
||||
)
|
||||
|
||||
if consecutive < _EMPTY_TOOL_CALL_LIMIT:
|
||||
return _EmptyToolBreakResult(consecutive, False, None, None, None)
|
||||
|
||||
logger.error(
|
||||
"%s Circuit breaker: aborting stream after %d "
|
||||
"consecutive empty tool calls. "
|
||||
"This is likely caused by the model attempting "
|
||||
"to write content too large for a single tool "
|
||||
"call's output token limit. The model should "
|
||||
"write large files in chunks using bash_exec "
|
||||
"with cat >> (append).",
|
||||
ctx.log_prefix,
|
||||
consecutive,
|
||||
)
|
||||
error_msg = _CIRCUIT_BREAKER_ERROR_MSG
|
||||
error_code = "circuit_breaker_empty_tool_calls"
|
||||
_append_error_marker(ctx.session, error_msg, retryable=True)
|
||||
return _EmptyToolBreakResult(
|
||||
count=consecutive,
|
||||
tripped=True,
|
||||
error=StreamError(errorText=error_msg, code=error_code),
|
||||
error_msg=error_msg,
|
||||
error_code=error_code,
|
||||
)
|
||||
|
||||
|
||||
async def _run_stream_attempt(
|
||||
ctx: _StreamContext,
|
||||
@@ -1039,8 +1207,24 @@ async def _run_stream_attempt(
|
||||
accumulated_tool_calls=[],
|
||||
)
|
||||
ended_with_stream_error = False
|
||||
# Stores the error message used by _append_error_marker so the outer
|
||||
# retry loop can re-append the correct message after session rollback.
|
||||
stream_error_msg: str | None = None
|
||||
stream_error_code: str | None = None
|
||||
|
||||
async with ClaudeSDKClient(options=state.options) as client:
|
||||
consecutive_empty_tool_calls = 0
|
||||
|
||||
# Use manual __aenter__/__aexit__ instead of ``async with`` so we can
|
||||
# suppress SDK cleanup errors that occur when the SSE client disconnects
|
||||
# mid-stream. GeneratorExit causes the SDK's ``__aexit__`` to run in a
|
||||
# different async context/task than where the client was opened, which
|
||||
# triggers:
|
||||
# - ValueError: ContextVar token mismatch (AUTOGPT-SERVER-8BT)
|
||||
# - RuntimeError: cancel scope in wrong task (AUTOGPT-SERVER-8BW)
|
||||
# Both are harmless — the TCP connection is already dead.
|
||||
sdk_client = ClaudeSDKClient(options=state.options)
|
||||
client = await sdk_client.__aenter__()
|
||||
try:
|
||||
logger.info(
|
||||
"%s Sending query — resume=%s, total_msgs=%d, "
|
||||
"query_len=%d, attached_files=%d, image_blocks=%d",
|
||||
@@ -1129,18 +1313,43 @@ async def _run_stream_attempt(
|
||||
"suppressing raw error text",
|
||||
ctx.log_prefix,
|
||||
)
|
||||
stream_error_msg = FRIENDLY_TRANSIENT_MSG
|
||||
stream_error_code = "transient_api_error"
|
||||
_append_error_marker(
|
||||
ctx.session,
|
||||
FRIENDLY_TRANSIENT_MSG,
|
||||
stream_error_msg,
|
||||
retryable=True,
|
||||
)
|
||||
yield StreamError(
|
||||
errorText=FRIENDLY_TRANSIENT_MSG,
|
||||
code="transient_api_error",
|
||||
errorText=stream_error_msg,
|
||||
code=stream_error_code,
|
||||
)
|
||||
ended_with_stream_error = True
|
||||
break
|
||||
|
||||
# Parallel tool execution: pre-launch every ToolUseBlock as an
|
||||
# asyncio.Task the moment its AssistantMessage arrives. The SDK
|
||||
# sends one AssistantMessage per tool call when issuing parallel
|
||||
# calls, so each message is pre-launched independently. The MCP
|
||||
# handlers will await the already-running task instead of executing
|
||||
# fresh, making all concurrent tool calls run in parallel.
|
||||
#
|
||||
# Also determine if the message is a tool-only batch (all content
|
||||
# items are ToolUseBlocks) — such messages have no text output yet,
|
||||
# so we skip the wait_for_stash flush below.
|
||||
is_tool_only = False
|
||||
if isinstance(sdk_msg, AssistantMessage) and sdk_msg.content:
|
||||
is_tool_only = True
|
||||
# NOTE: Pre-launches are sequential (each await completes
|
||||
# file-ref expansion before the next starts). This is fine
|
||||
# since expansion is typically sub-ms; a future optimisation
|
||||
# could gather all pre-launches concurrently.
|
||||
for tool_use in sdk_msg.content:
|
||||
if isinstance(tool_use, ToolUseBlock):
|
||||
await pre_launch_tool_call(tool_use.name, tool_use.input)
|
||||
else:
|
||||
is_tool_only = False
|
||||
|
||||
# Race-condition fix: SDK hooks (PostToolUse) are
|
||||
# executed asynchronously via start_soon() — the next
|
||||
# message can arrive before the hook stashes output.
|
||||
@@ -1154,15 +1363,12 @@ async def _run_stream_attempt(
|
||||
# AssistantMessages (each containing only
|
||||
# ToolUseBlocks), we must NOT wait/flush — the prior
|
||||
# tools are still executing concurrently.
|
||||
is_parallel_continuation = isinstance(sdk_msg, AssistantMessage) and all(
|
||||
isinstance(b, ToolUseBlock) for b in sdk_msg.content
|
||||
)
|
||||
if (
|
||||
state.adapter.has_unresolved_tool_calls
|
||||
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
|
||||
and not is_parallel_continuation
|
||||
and not is_tool_only
|
||||
):
|
||||
if await wait_for_stash(timeout=0.5):
|
||||
if await wait_for_stash():
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -1177,13 +1383,17 @@ async def _run_stream_attempt(
|
||||
if isinstance(sdk_msg, ResultMessage):
|
||||
logger.info(
|
||||
"%s Received: ResultMessage %s "
|
||||
"(unresolved=%d, current=%d, resolved=%d)",
|
||||
"(unresolved=%d, current=%d, resolved=%d, "
|
||||
"num_turns=%d, cost_usd=%s, result=%s)",
|
||||
ctx.log_prefix,
|
||||
sdk_msg.subtype,
|
||||
len(state.adapter.current_tool_calls)
|
||||
- len(state.adapter.resolved_tool_calls),
|
||||
len(state.adapter.current_tool_calls),
|
||||
len(state.adapter.resolved_tool_calls),
|
||||
sdk_msg.num_turns,
|
||||
sdk_msg.total_cost_usd,
|
||||
(sdk_msg.result or "")[:200],
|
||||
)
|
||||
if sdk_msg.subtype in (
|
||||
"error",
|
||||
@@ -1240,6 +1450,18 @@ async def _run_stream_attempt(
|
||||
)
|
||||
entries_replaced = True
|
||||
|
||||
# --- Hard circuit breaker for empty tool calls ---
|
||||
breaker = _check_empty_tool_breaker(
|
||||
sdk_msg, consecutive_empty_tool_calls, ctx, state
|
||||
)
|
||||
consecutive_empty_tool_calls = breaker.count
|
||||
if breaker.tripped and breaker.error is not None:
|
||||
stream_error_msg = breaker.error_msg
|
||||
stream_error_code = breaker.error_code
|
||||
yield breaker.error
|
||||
ended_with_stream_error = True
|
||||
break
|
||||
|
||||
# --- Dispatch adapter responses ---
|
||||
for response in state.adapter.convert_message(sdk_msg):
|
||||
dispatched = _dispatch_response(
|
||||
@@ -1262,6 +1484,8 @@ async def _run_stream_attempt(
|
||||
|
||||
if acc.stream_completed:
|
||||
break
|
||||
finally:
|
||||
await _safe_close_sdk_client(sdk_client, ctx.log_prefix)
|
||||
|
||||
# --- Post-stream processing (only on success) ---
|
||||
if state.adapter.has_unresolved_tool_calls:
|
||||
@@ -1320,8 +1544,10 @@ async def _run_stream_attempt(
|
||||
# to the client (StreamError yielded above), raise so the outer retry
|
||||
# loop can rollback session messages and set its error flags properly.
|
||||
if ended_with_stream_error:
|
||||
raise _TransientErrorHandled(
|
||||
"Transient API error handled — StreamError already yielded"
|
||||
raise _HandledStreamError(
|
||||
"Stream error handled — StreamError already yielded",
|
||||
error_msg=stream_error_msg,
|
||||
code=stream_error_code,
|
||||
)
|
||||
|
||||
|
||||
@@ -1332,6 +1558,7 @@ async def stream_chat_completion_sdk(
|
||||
user_id: str | None = None,
|
||||
session: ChatSession | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncIterator[StreamBaseResponse]:
|
||||
"""Stream chat completion using Claude Agent SDK.
|
||||
@@ -1577,10 +1804,16 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
set_execution_context(user_id, session, sandbox=e2b_sandbox, sdk_cwd=sdk_cwd)
|
||||
set_execution_context(
|
||||
user_id,
|
||||
session,
|
||||
sandbox=e2b_sandbox,
|
||||
sdk_cwd=sdk_cwd,
|
||||
permissions=permissions,
|
||||
)
|
||||
|
||||
# Fail fast when no API credentials are available at all.
|
||||
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
|
||||
sdk_env = build_sdk_env(session_id=session_id, user_id=user_id)
|
||||
if not config.api_key and not config.use_claude_code_subscription:
|
||||
raise RuntimeError(
|
||||
"No API key configured. Set OPEN_ROUTER_API_KEY, "
|
||||
@@ -1603,8 +1836,11 @@ async def stream_chat_completion_sdk(
|
||||
on_compact=compaction.on_compact,
|
||||
)
|
||||
|
||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
if permissions is not None:
|
||||
allowed, disallowed = apply_tool_permissions(permissions, use_e2b=use_e2b)
|
||||
else:
|
||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
|
||||
def _on_stderr(line: str) -> None:
|
||||
"""Log a stderr line emitted by the Claude CLI subprocess."""
|
||||
@@ -1714,6 +1950,12 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
|
||||
for attempt in range(_MAX_STREAM_ATTEMPTS):
|
||||
# Clear any stale stash signal from the previous attempt so
|
||||
# wait_for_stash() doesn't fire prematurely on a leftover event.
|
||||
reset_stash_event()
|
||||
# Reset tool-level circuit breaker so failures from a previous
|
||||
# (rolled-back) attempt don't carry over to the fresh attempt.
|
||||
reset_tool_failure_counters()
|
||||
if attempt > 0:
|
||||
logger.info(
|
||||
"%s Retrying with reduced context (%d/%d)",
|
||||
@@ -1769,6 +2011,10 @@ async def stream_chat_completion_sdk(
|
||||
if not isinstance(event, StreamHeartbeat):
|
||||
events_yielded += 1
|
||||
yield event
|
||||
# Cancel any pre-launched tasks that were never dispatched
|
||||
# by the SDK (e.g. edge-case SDK behaviour changes). Symmetric
|
||||
# with the three error-path await cancel_pending_tool_tasks() calls.
|
||||
await cancel_pending_tool_tasks()
|
||||
break # Stream completed — exit retry loop
|
||||
except asyncio.CancelledError:
|
||||
logger.warning(
|
||||
@@ -1777,26 +2023,42 @@ async def stream_chat_completion_sdk(
|
||||
attempt + 1,
|
||||
_MAX_STREAM_ATTEMPTS,
|
||||
)
|
||||
# Cancel any pre-launched tasks so they don't continue executing
|
||||
# against a rolled-back or abandoned session.
|
||||
await cancel_pending_tool_tasks()
|
||||
raise
|
||||
except _TransientErrorHandled:
|
||||
except _HandledStreamError as exc:
|
||||
# _run_stream_attempt already yielded a StreamError and
|
||||
# appended an error marker. We only need to rollback
|
||||
# session messages and set the error flag — do NOT set
|
||||
# stream_err so the post-loop code won't emit a
|
||||
# duplicate StreamError.
|
||||
logger.warning(
|
||||
"%s Transient error handled in stream attempt "
|
||||
"(attempt %d/%d, events_yielded=%d)",
|
||||
"%s Stream error handled in attempt "
|
||||
"(attempt %d/%d, code=%s, events_yielded=%d)",
|
||||
log_prefix,
|
||||
attempt + 1,
|
||||
_MAX_STREAM_ATTEMPTS,
|
||||
exc.code or "transient",
|
||||
events_yielded,
|
||||
)
|
||||
session.messages = session.messages[:pre_attempt_msg_count]
|
||||
# transcript_builder still contains entries from the aborted
|
||||
# attempt that no longer match session.messages. Skip upload
|
||||
# so a future --resume doesn't replay rolled-back content.
|
||||
skip_transcript_upload = True
|
||||
# Re-append the error marker so it survives the rollback
|
||||
# and is persisted by the finally block (see #2947655365).
|
||||
_append_error_marker(session, FRIENDLY_TRANSIENT_MSG, retryable=True)
|
||||
# Use the specific error message from the attempt (e.g.
|
||||
# circuit breaker msg) rather than always the generic one.
|
||||
_append_error_marker(
|
||||
session,
|
||||
exc.error_msg or FRIENDLY_TRANSIENT_MSG,
|
||||
retryable=True,
|
||||
)
|
||||
ended_with_stream_error = True
|
||||
# Cancel any pre-launched tasks from the failed attempt.
|
||||
await cancel_pending_tool_tasks()
|
||||
break
|
||||
except Exception as e:
|
||||
stream_err = e
|
||||
@@ -1813,6 +2075,9 @@ async def stream_chat_completion_sdk(
|
||||
exc_info=True,
|
||||
)
|
||||
session.messages = session.messages[:pre_attempt_msg_count]
|
||||
# Cancel any pre-launched tasks from the failed attempt so they
|
||||
# don't continue executing against the rolled-back session.
|
||||
await cancel_pending_tool_tasks()
|
||||
if events_yielded > 0:
|
||||
# Events were already sent to the frontend and cannot be
|
||||
# unsent. Retrying would produce duplicate/inconsistent
|
||||
@@ -1822,11 +2087,13 @@ async def stream_chat_completion_sdk(
|
||||
log_prefix,
|
||||
events_yielded,
|
||||
)
|
||||
skip_transcript_upload = True
|
||||
ended_with_stream_error = True
|
||||
break
|
||||
if not is_context_error:
|
||||
# Non-context errors (network, auth, rate-limit) should
|
||||
# not trigger compaction — surface the error immediately.
|
||||
skip_transcript_upload = True
|
||||
ended_with_stream_error = True
|
||||
break
|
||||
continue
|
||||
@@ -1922,6 +2189,16 @@ async def stream_chat_completion_sdk(
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
)
|
||||
except GeneratorExit:
|
||||
# GeneratorExit is raised when the async generator is closed by the
|
||||
# caller (e.g. client disconnect, page refresh). We MUST release the
|
||||
# stream lock here because the ``finally`` block at the end of this
|
||||
# function may not execute when GeneratorExit propagates through nested
|
||||
# async generators. Without this, the lock stays held for its full TTL
|
||||
# and the user sees "Another stream is already active" on every retry.
|
||||
logger.warning("%s GeneratorExit — releasing stream lock", log_prefix)
|
||||
await lock.release()
|
||||
raise
|
||||
except BaseException as e:
|
||||
# Catch BaseException to handle both Exception and CancelledError
|
||||
# (CancelledError inherits from BaseException in Python 3.8+)
|
||||
@@ -1930,9 +2207,16 @@ async def stream_chat_completion_sdk(
|
||||
error_msg = "Operation cancelled"
|
||||
else:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
# SDK cleanup RuntimeError is expected during cancellation, log as warning
|
||||
if isinstance(e, RuntimeError) and "cancel scope" in str(e):
|
||||
logger.warning("%s SDK cleanup error: %s", log_prefix, error_msg)
|
||||
# SDK cleanup errors are expected during client disconnect —
|
||||
# log as warning rather than error to reduce Sentry noise.
|
||||
# These are normally caught by _safe_close_sdk_client but
|
||||
# can escape in edge cases (e.g. GeneratorExit timing).
|
||||
if _is_sdk_disconnect_error(e):
|
||||
logger.warning(
|
||||
"%s SDK cleanup error (client disconnect): %s",
|
||||
log_prefix,
|
||||
error_msg,
|
||||
)
|
||||
else:
|
||||
logger.error("%s Error: %s", log_prefix, error_msg, exc_info=True)
|
||||
|
||||
@@ -1954,10 +2238,11 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
|
||||
# Yield StreamError for immediate feedback (only for non-cancellation errors)
|
||||
# Skip for CancelledError and RuntimeError cleanup issues (both are cancellations)
|
||||
is_cancellation = isinstance(e, asyncio.CancelledError) or (
|
||||
isinstance(e, RuntimeError) and "cancel scope" in str(e)
|
||||
)
|
||||
# Skip for CancelledError and SDK disconnect cleanup errors — these
|
||||
# are not actionable by the user and the SSE connection is already dead.
|
||||
is_cancellation = isinstance(
|
||||
e, asyncio.CancelledError
|
||||
) or _is_sdk_disconnect_error(e)
|
||||
if not is_cancellation:
|
||||
yield StreamError(errorText=display_msg, code=code)
|
||||
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
"""Unit tests for extracted service helpers.
|
||||
|
||||
Covers ``_is_prompt_too_long``, ``_reduce_context``, ``_iter_sdk_messages``,
|
||||
and the ``ReducedContext`` named tuple.
|
||||
``ReducedContext``, and the ``is_parallel_continuation`` logic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from claude_agent_sdk import AssistantMessage, TextBlock, ToolUseBlock
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import (
|
||||
ReducedContext,
|
||||
_is_prompt_too_long,
|
||||
_is_tool_only_message,
|
||||
_iter_sdk_messages,
|
||||
_reduce_context,
|
||||
)
|
||||
@@ -281,3 +283,55 @@ class TestIterSdkMessages:
|
||||
first = await gen.__anext__()
|
||||
assert first == "first"
|
||||
await gen.aclose() # should cancel pending task cleanly
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_parallel_continuation logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsParallelContinuation:
|
||||
"""Unit tests for the is_parallel_continuation expression in the streaming loop.
|
||||
|
||||
Verifies the vacuous-truth guard (empty content must return False) and the
|
||||
boundary cases for mixed TextBlock+ToolUseBlock messages.
|
||||
"""
|
||||
|
||||
def _make_tool_block(self) -> MagicMock:
|
||||
block = MagicMock(spec=ToolUseBlock)
|
||||
return block
|
||||
|
||||
def test_all_tool_use_blocks_is_parallel(self):
|
||||
"""AssistantMessage with only ToolUseBlocks is a parallel continuation."""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
msg.content = [self._make_tool_block(), self._make_tool_block()]
|
||||
assert _is_tool_only_message(msg) is True
|
||||
|
||||
def test_empty_content_is_not_parallel(self):
|
||||
"""AssistantMessage with empty content must NOT be treated as parallel.
|
||||
|
||||
Without the bool(sdk_msg.content) guard, all() on an empty iterable
|
||||
returns True via vacuous truth — this test ensures the guard is present.
|
||||
"""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
msg.content = []
|
||||
assert _is_tool_only_message(msg) is False
|
||||
|
||||
def test_mixed_text_and_tool_blocks_not_parallel(self):
|
||||
"""AssistantMessage with text + tool blocks is NOT a parallel continuation."""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
text_block = MagicMock(spec=TextBlock)
|
||||
msg.content = [text_block, self._make_tool_block()]
|
||||
assert _is_tool_only_message(msg) is False
|
||||
|
||||
def test_non_assistant_message_not_parallel(self):
|
||||
"""Non-AssistantMessage types are never parallel continuations."""
|
||||
assert _is_tool_only_message("not a message") is False
|
||||
assert _is_tool_only_message(None) is False
|
||||
assert _is_tool_only_message(42) is False
|
||||
|
||||
def test_single_tool_block_is_parallel(self):
|
||||
"""Single ToolUseBlock AssistantMessage is a parallel continuation."""
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
msg.content = [self._make_tool_block()]
|
||||
assert _is_tool_only_message(msg) is True
|
||||
|
||||
@@ -8,7 +8,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .service import _prepare_file_attachments, _resolve_sdk_model
|
||||
from .service import (
|
||||
_is_sdk_disconnect_error,
|
||||
_prepare_file_attachments,
|
||||
_resolve_sdk_model,
|
||||
_safe_close_sdk_client,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -499,3 +504,111 @@ class TestResolveSdkModel:
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_sdk_disconnect_error — classify client disconnect cleanup errors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsSdkDisconnectError:
|
||||
"""Tests for _is_sdk_disconnect_error — identifies expected SDK cleanup errors."""
|
||||
|
||||
def test_cancel_scope_runtime_error(self):
|
||||
"""RuntimeError about cancel scope in wrong task is a disconnect error."""
|
||||
exc = RuntimeError(
|
||||
"Attempted to exit cancel scope in a different task than it was entered in"
|
||||
)
|
||||
assert _is_sdk_disconnect_error(exc) is True
|
||||
|
||||
def test_context_var_value_error(self):
|
||||
"""ValueError about ContextVar token mismatch is a disconnect error."""
|
||||
exc = ValueError(
|
||||
"<Token var=<ContextVar name='current_context'>> "
|
||||
"was created in a different Context"
|
||||
)
|
||||
assert _is_sdk_disconnect_error(exc) is True
|
||||
|
||||
def test_unrelated_runtime_error(self):
|
||||
"""Unrelated RuntimeError should NOT be classified as disconnect error."""
|
||||
exc = RuntimeError("something else went wrong")
|
||||
assert _is_sdk_disconnect_error(exc) is False
|
||||
|
||||
def test_unrelated_value_error(self):
|
||||
"""Unrelated ValueError should NOT be classified as disconnect error."""
|
||||
exc = ValueError("invalid argument")
|
||||
assert _is_sdk_disconnect_error(exc) is False
|
||||
|
||||
def test_other_exception_types(self):
|
||||
"""Non-RuntimeError/ValueError should NOT be classified as disconnect error."""
|
||||
assert _is_sdk_disconnect_error(TypeError("bad type")) is False
|
||||
assert _is_sdk_disconnect_error(OSError("network down")) is False
|
||||
assert _is_sdk_disconnect_error(asyncio.CancelledError()) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _safe_close_sdk_client — suppress cleanup errors during disconnect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSafeCloseSdkClient:
|
||||
"""Tests for _safe_close_sdk_client — suppresses expected SDK cleanup errors."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clean_exit(self):
|
||||
"""Normal __aexit__ (no error) should succeed silently."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(return_value=None)
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
client.__aexit__.assert_awaited_once_with(None, None, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_scope_runtime_error_suppressed(self):
|
||||
"""RuntimeError from cancel scope mismatch should be suppressed."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(
|
||||
side_effect=RuntimeError(
|
||||
"Attempted to exit cancel scope in a different task"
|
||||
)
|
||||
)
|
||||
# Should NOT raise
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_var_value_error_suppressed(self):
|
||||
"""ValueError from ContextVar token mismatch should be suppressed."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(
|
||||
side_effect=ValueError(
|
||||
"<Token var=<ContextVar name='current_context'>> "
|
||||
"was created in a different Context"
|
||||
)
|
||||
)
|
||||
# Should NOT raise
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unexpected_exception_suppressed_with_error_log(self):
|
||||
"""Unexpected exceptions should be caught (not propagated) but logged at error."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(side_effect=OSError("unexpected"))
|
||||
# Should NOT raise — unexpected errors are also suppressed to
|
||||
# avoid crashing the generator during teardown. Logged at error
|
||||
# level so Sentry captures them via its logging integration.
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unrelated_runtime_error_propagates(self):
|
||||
"""Non-cancel-scope RuntimeError should propagate (not suppressed)."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(side_effect=RuntimeError("something unrelated"))
|
||||
with pytest.raises(RuntimeError, match="something unrelated"):
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unrelated_value_error_propagates(self):
|
||||
"""Non-disconnect ValueError should propagate (not suppressed)."""
|
||||
client = AsyncMock()
|
||||
client.__aexit__ = AsyncMock(side_effect=ValueError("invalid argument"))
|
||||
with pytest.raises(ValueError, match="invalid argument"):
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Tests for the tool call circuit breaker in tool_adapter.py."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.tool_adapter import (
|
||||
_MAX_CONSECUTIVE_TOOL_FAILURES,
|
||||
_check_circuit_breaker,
|
||||
_clear_tool_failures,
|
||||
_consecutive_tool_failures,
|
||||
_record_tool_failure,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_tracker():
|
||||
"""Reset the circuit breaker tracker for each test."""
|
||||
token = _consecutive_tool_failures.set({})
|
||||
yield
|
||||
_consecutive_tool_failures.reset(token)
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
def test_no_trip_below_threshold(self):
|
||||
"""Circuit breaker should not trip before reaching the limit."""
|
||||
args = {"file_path": "/tmp/test.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
_record_tool_failure("write_file", args)
|
||||
# Still under the limit
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
|
||||
def test_trips_at_threshold(self):
|
||||
"""Circuit breaker should trip after reaching the failure limit."""
|
||||
args = {"file_path": "/tmp/test.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
_record_tool_failure("write_file", args)
|
||||
# Now it should trip
|
||||
result = _check_circuit_breaker("write_file", args)
|
||||
assert result is not None
|
||||
assert "STOP" in result
|
||||
assert "write_file" in result
|
||||
|
||||
def test_different_args_tracked_separately(self):
|
||||
"""Different args should have separate failure counters."""
|
||||
args_a = {"file_path": "/tmp/a.txt"}
|
||||
args_b = {"file_path": "/tmp/b.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("write_file", args_a)
|
||||
# args_a should trip
|
||||
assert _check_circuit_breaker("write_file", args_a) is not None
|
||||
# args_b should NOT trip
|
||||
assert _check_circuit_breaker("write_file", args_b) is None
|
||||
|
||||
def test_different_tools_tracked_separately(self):
|
||||
"""Different tools should have separate failure counters."""
|
||||
args = {"file_path": "/tmp/test.txt"}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("tool_a", args)
|
||||
# tool_a should trip
|
||||
assert _check_circuit_breaker("tool_a", args) is not None
|
||||
# tool_b with same args should NOT trip
|
||||
assert _check_circuit_breaker("tool_b", args) is None
|
||||
|
||||
def test_empty_args_tracked(self):
|
||||
"""Empty args ({}) — the exact failure pattern from the bug — should be tracked."""
|
||||
args = {}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("write_file", args)
|
||||
assert _check_circuit_breaker("write_file", args) is not None
|
||||
|
||||
def test_clear_resets_counter(self):
|
||||
"""Clearing failures should reset the counter."""
|
||||
args = {}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES):
|
||||
_record_tool_failure("write_file", args)
|
||||
_clear_tool_failures("write_file")
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
|
||||
def test_success_clears_failures(self):
|
||||
"""A successful call should reset the failure counter."""
|
||||
args = {}
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
|
||||
_record_tool_failure("write_file", args)
|
||||
# Success clears failures
|
||||
_clear_tool_failures("write_file")
|
||||
# Should be able to fail again without tripping
|
||||
for _ in range(_MAX_CONSECUTIVE_TOOL_FAILURES - 1):
|
||||
_record_tool_failure("write_file", args)
|
||||
assert _check_circuit_breaker("write_file", args) is None
|
||||
|
||||
def test_no_tracker_returns_none(self):
|
||||
"""If tracker is not initialized, circuit breaker should not trip."""
|
||||
_consecutive_tool_failures.set(None) # type: ignore[arg-type]
|
||||
_record_tool_failure("write_file", {}) # should not raise
|
||||
assert _check_circuit_breaker("write_file", {}) is None
|
||||
@@ -0,0 +1,822 @@
|
||||
"""Tests for thinking/redacted_thinking block preservation.
|
||||
|
||||
Validates the fix for the Anthropic API error:
|
||||
"thinking or redacted_thinking blocks in the latest assistant message
|
||||
cannot be modified. These blocks must remain as they were in the
|
||||
original response."
|
||||
|
||||
The API requires that thinking blocks in the LAST assistant message are
|
||||
preserved value-identical. Older assistant messages may have thinking blocks
|
||||
stripped entirely. This test suite covers:
|
||||
|
||||
1. _flatten_assistant_content — strips thinking from older messages
|
||||
2. compact_transcript — preserves last assistant's thinking blocks
|
||||
3. response_adapter — handles ThinkingBlock without error
|
||||
4. _format_sdk_content_blocks — preserves redacted_thinking blocks
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from claude_agent_sdk import AssistantMessage, TextBlock, ThinkingBlock
|
||||
|
||||
from backend.copilot.response_model import (
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextStart,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_structured_transcript
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .service import _format_sdk_content_blocks
|
||||
from .transcript import (
|
||||
_find_last_assistant_entry,
|
||||
_flatten_assistant_content,
|
||||
_messages_to_transcript,
|
||||
_rechain_tail,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic thinking block content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
THINKING_BLOCK = {
|
||||
"type": "thinking",
|
||||
"thinking": "Let me analyze the user's request carefully...",
|
||||
"signature": "ErUBCkYIAxgCIkD0V2MsRXPkuGolGexaW9V1kluijxXGF",
|
||||
}
|
||||
|
||||
REDACTED_THINKING_BLOCK = {
|
||||
"type": "redacted_thinking",
|
||||
"data": "EmwKAhgBEgy2VEE8PJaS2oLJCPkaT...",
|
||||
}
|
||||
|
||||
|
||||
def _make_thinking_transcript() -> str:
|
||||
"""Build a transcript with thinking blocks in multiple assistant turns.
|
||||
|
||||
Layout:
|
||||
User 1 → Assistant 1 (thinking + text + tool_use)
|
||||
User 2 (tool_result) → Assistant 2 (thinking + text)
|
||||
User 3 → Assistant 3 (thinking + redacted_thinking + text) ← LAST
|
||||
"""
|
||||
return build_structured_transcript(
|
||||
[
|
||||
("user", "What files are in this project?"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "I should list the files.",
|
||||
"signature": "sig_old_1",
|
||||
},
|
||||
{"type": "text", "text": "Let me check the files."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "list_files",
|
||||
"input": {"path": "/"},
|
||||
},
|
||||
],
|
||||
),
|
||||
("user", "Here are the files: a.py, b.py"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "Good, I see two Python files.",
|
||||
"signature": "sig_old_2",
|
||||
},
|
||||
{"type": "text", "text": "I found a.py and b.py."},
|
||||
],
|
||||
),
|
||||
("user", "Tell me about a.py"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
THINKING_BLOCK,
|
||||
REDACTED_THINKING_BLOCK,
|
||||
{"type": "text", "text": "a.py contains the main entry point."},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _last_assistant_content(transcript_jsonl: str) -> list[dict] | None:
|
||||
"""Extract the content blocks of the last assistant entry in a transcript."""
|
||||
last_content = None
|
||||
for line in transcript_jsonl.strip().split("\n"):
|
||||
entry = json.loads(line)
|
||||
msg = entry.get("message", {})
|
||||
if msg.get("role") == "assistant":
|
||||
last_content = msg.get("content")
|
||||
return last_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_last_assistant_entry — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindLastAssistantEntry:
|
||||
def test_splits_at_last_assistant(self):
|
||||
"""Prefix contains everything before last assistant; tail starts at it."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", [{"type": "text", "text": "Hi"}]),
|
||||
("user", "More"),
|
||||
("assistant", [{"type": "text", "text": "Details"}]),
|
||||
]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
# 3 entries in prefix (user, assistant, user), 1 in tail (last assistant)
|
||||
assert len(prefix) == 3
|
||||
assert len(tail) == 1
|
||||
|
||||
def test_no_assistant_returns_all_in_prefix(self):
|
||||
"""When there's no assistant, all lines are in prefix, tail is empty."""
|
||||
transcript = build_structured_transcript(
|
||||
[("user", "Hello"), ("user", "Another question")]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert len(prefix) == 2
|
||||
assert tail == []
|
||||
|
||||
def test_assistant_at_index_zero(self):
|
||||
"""When assistant is the first entry, prefix is empty."""
|
||||
transcript = build_structured_transcript(
|
||||
[("assistant", [{"type": "text", "text": "Start"}])]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert prefix == []
|
||||
assert len(tail) == 1
|
||||
|
||||
def test_trailing_user_included_in_tail(self):
|
||||
"""User message after last assistant is part of the tail."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Q1"),
|
||||
("assistant", [{"type": "text", "text": "A1"}]),
|
||||
("user", "Q2"),
|
||||
]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert len(prefix) == 1 # first user
|
||||
assert len(tail) == 2 # last assistant + trailing user
|
||||
|
||||
def test_multi_entry_turn_fully_preserved(self):
|
||||
"""An assistant turn spanning multiple JSONL entries (same message.id)
|
||||
must be entirely in the tail, not split across prefix and tail."""
|
||||
# Build manually because build_structured_transcript generates unique ids
|
||||
lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "Hello"},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1-think",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_same_turn",
|
||||
"type": "message",
|
||||
"content": [THINKING_BLOCK],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1-tool",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_same_turn",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "Bash",
|
||||
"input": {},
|
||||
},
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
transcript = "\n".join(lines) + "\n"
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
# Both assistant entries share msg_same_turn → both in tail
|
||||
assert len(prefix) == 1 # only the user entry
|
||||
assert len(tail) == 2 # both assistant entries (thinking + tool_use)
|
||||
|
||||
def test_no_message_id_preserves_last_assistant(self):
|
||||
"""When the last assistant entry has no message.id, it should still
|
||||
be preserved in the tail (fail closed) rather than being compressed."""
|
||||
lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "Hello"},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [THINKING_BLOCK, {"type": "text", "text": "Hi"}],
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
transcript = "\n".join(lines) + "\n"
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert len(prefix) == 1 # user entry
|
||||
assert len(tail) == 1 # assistant entry preserved
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _rechain_tail — UUID chain patching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRechainTail:
|
||||
def test_patches_first_entry_parentuuid(self):
|
||||
"""First tail entry's parentUuid should point to last prefix uuid."""
|
||||
prefix = _messages_to_transcript(
|
||||
[
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
)
|
||||
# Get the last uuid from the prefix
|
||||
last_prefix_uuid = None
|
||||
for line in prefix.strip().split("\n"):
|
||||
entry = json.loads(line)
|
||||
last_prefix_uuid = entry.get("uuid")
|
||||
|
||||
tail_lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "tail-a1",
|
||||
"parentUuid": "old-parent",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Tail msg"}],
|
||||
},
|
||||
}
|
||||
)
|
||||
]
|
||||
result = _rechain_tail(prefix, tail_lines)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["parentUuid"] == last_prefix_uuid
|
||||
assert entry["uuid"] == "tail-a1" # uuid preserved
|
||||
|
||||
def test_chains_multiple_tail_entries(self):
|
||||
"""Subsequent tail entries chain to each other."""
|
||||
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
|
||||
tail_lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "t1",
|
||||
"parentUuid": "old1",
|
||||
"message": {"role": "assistant", "content": []},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "t2",
|
||||
"parentUuid": "old2",
|
||||
"message": {"role": "user", "content": "Follow-up"},
|
||||
}
|
||||
),
|
||||
]
|
||||
result = _rechain_tail(prefix, tail_lines)
|
||||
entries = [json.loads(ln) for ln in result.strip().split("\n")]
|
||||
assert len(entries) == 2
|
||||
# Second entry's parentUuid should be first entry's uuid
|
||||
assert entries[1]["parentUuid"] == "t1"
|
||||
|
||||
def test_empty_tail_returns_empty(self):
|
||||
"""No tail entries → empty string."""
|
||||
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
|
||||
assert _rechain_tail(prefix, []) == ""
|
||||
|
||||
def test_preserves_message_content_verbatim(self):
|
||||
"""Tail message content (including thinking blocks) must not be modified."""
|
||||
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
|
||||
original_content = [
|
||||
THINKING_BLOCK,
|
||||
REDACTED_THINKING_BLOCK,
|
||||
{"type": "text", "text": "Response"},
|
||||
]
|
||||
tail_lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "t1",
|
||||
"parentUuid": "old",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": original_content,
|
||||
},
|
||||
}
|
||||
)
|
||||
]
|
||||
result = _rechain_tail(prefix, tail_lines)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["message"]["content"] == original_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content — thinking blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenThinkingBlocks:
|
||||
def test_thinking_blocks_are_stripped(self):
|
||||
"""Thinking blocks should not appear in flattened text for compression."""
|
||||
blocks = [
|
||||
{"type": "thinking", "thinking": "secret thoughts", "signature": "sig"},
|
||||
{"type": "text", "text": "Hello user"},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "secret thoughts" not in result
|
||||
assert "Hello user" in result
|
||||
|
||||
def test_redacted_thinking_blocks_are_stripped(self):
|
||||
"""Redacted thinking blocks should not appear in flattened text."""
|
||||
blocks = [
|
||||
{"type": "redacted_thinking", "data": "encrypted_data"},
|
||||
{"type": "text", "text": "Response text"},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "encrypted_data" not in result
|
||||
assert "Response text" in result
|
||||
|
||||
def test_thinking_only_message_flattens_to_empty(self):
|
||||
"""A message with only thinking blocks flattens to empty string."""
|
||||
blocks = [
|
||||
{"type": "thinking", "thinking": "just thinking...", "signature": "sig"},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert result == ""
|
||||
|
||||
def test_mixed_thinking_text_tool(self):
|
||||
"""Mixed blocks: only text and tool_use survive flattening."""
|
||||
blocks = [
|
||||
{"type": "thinking", "thinking": "hmm", "signature": "sig"},
|
||||
{"type": "redacted_thinking", "data": "xyz"},
|
||||
{"type": "text", "text": "I'll read the file."},
|
||||
{"type": "tool_use", "name": "Read", "input": {"path": "/x"}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "hmm" not in result
|
||||
assert "xyz" not in result
|
||||
assert "I'll read the file." in result
|
||||
assert "[tool_use: Read]" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compact_transcript — thinking block preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactTranscriptThinkingBlocks:
|
||||
"""Verify that compact_transcript preserves thinking blocks in the
|
||||
last assistant message while stripping them from older messages."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_last_assistant_thinking_blocks_preserved(self, mock_chat_config):
|
||||
"""After compaction, the last assistant entry must retain its
|
||||
original thinking and redacted_thinking blocks verbatim."""
|
||||
transcript = _make_thinking_transcript()
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[conversation summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 800,
|
||||
"token_count": 200,
|
||||
"messages_summarized": 4,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None, "No assistant entry found"
|
||||
assert isinstance(last_content, list)
|
||||
|
||||
# The last assistant must have the thinking blocks preserved
|
||||
block_types = [b["type"] for b in last_content]
|
||||
assert (
|
||||
"thinking" in block_types
|
||||
), "thinking block missing from last assistant message"
|
||||
assert (
|
||||
"redacted_thinking" in block_types
|
||||
), "redacted_thinking block missing from last assistant message"
|
||||
assert "text" in block_types
|
||||
|
||||
# Verify the thinking block content is value-identical
|
||||
thinking_blocks = [b for b in last_content if b["type"] == "thinking"]
|
||||
assert len(thinking_blocks) == 1
|
||||
assert thinking_blocks[0]["thinking"] == THINKING_BLOCK["thinking"]
|
||||
assert thinking_blocks[0]["signature"] == THINKING_BLOCK["signature"]
|
||||
|
||||
redacted_blocks = [b for b in last_content if b["type"] == "redacted_thinking"]
|
||||
assert len(redacted_blocks) == 1
|
||||
assert redacted_blocks[0]["data"] == REDACTED_THINKING_BLOCK["data"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_older_assistant_thinking_blocks_stripped(self, mock_chat_config):
|
||||
"""Older assistant messages should NOT retain thinking blocks
|
||||
after compaction (they're compressed into summaries)."""
|
||||
transcript = _make_thinking_transcript()
|
||||
|
||||
# The compressor will receive messages where older assistant
|
||||
# entries have already had thinking blocks stripped.
|
||||
captured_messages: list[dict] = []
|
||||
|
||||
async def mock_compression(messages, model, log_prefix):
|
||||
captured_messages.extend(messages)
|
||||
return type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": messages,
|
||||
"original_token_count": 800,
|
||||
"token_count": 400,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
side_effect=mock_compression,
|
||||
):
|
||||
await compact_transcript(transcript, model="test-model")
|
||||
|
||||
# Check that the messages sent to compression don't contain
|
||||
# thinking content from older assistant messages
|
||||
for msg in captured_messages:
|
||||
if msg["role"] == "assistant":
|
||||
content = msg.get("content", "")
|
||||
assert (
|
||||
"I should list the files." not in content
|
||||
), "Old thinking block content leaked into compression input"
|
||||
assert (
|
||||
"Good, I see two Python files." not in content
|
||||
), "Old thinking block content leaked into compression input"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trailing_user_message_after_last_assistant(self, mock_chat_config):
|
||||
"""When the last entry is a user message, the last *assistant*
|
||||
message's thinking blocks should still be preserved."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
THINKING_BLOCK,
|
||||
{"type": "text", "text": "Hi there"},
|
||||
],
|
||||
),
|
||||
("user", "Follow-up question"),
|
||||
]
|
||||
)
|
||||
|
||||
# The compressor only receives the prefix (1 user message); the
|
||||
# tail (assistant + trailing user) is preserved verbatim.
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 400,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None
|
||||
assert isinstance(last_content, list)
|
||||
block_types = [b["type"] for b in last_content]
|
||||
assert (
|
||||
"thinking" in block_types
|
||||
), "thinking block lost from last assistant despite trailing user msg"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_assistant_with_thinking_preserved(self, mock_chat_config):
|
||||
"""When there's only one assistant message (which is also the last),
|
||||
its thinking blocks must be preserved."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
THINKING_BLOCK,
|
||||
{"type": "text", "text": "World"},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "World"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 200,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None
|
||||
assert isinstance(last_content, list)
|
||||
block_types = [b["type"] for b in last_content]
|
||||
assert "thinking" in block_types
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tail_parentuuid_rewired_to_prefix(self, mock_chat_config):
|
||||
"""After compaction, the first tail entry's parentUuid must point to
|
||||
the last entry in the compressed prefix — not its original parent."""
|
||||
transcript = _make_thinking_transcript()
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[conversation summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 800,
|
||||
"token_count": 200,
|
||||
"messages_summarized": 4,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
lines = [ln for ln in result.strip().split("\n") if ln.strip()]
|
||||
entries = [json.loads(ln) for ln in lines]
|
||||
|
||||
# Find the boundary: the compressed prefix ends just before the
|
||||
# first tail entry (last assistant in original transcript).
|
||||
tail_start = None
|
||||
for i, entry in enumerate(entries):
|
||||
msg = entry.get("message", {})
|
||||
if isinstance(msg.get("content"), list):
|
||||
# Structured content = preserved tail entry
|
||||
tail_start = i
|
||||
break
|
||||
|
||||
assert tail_start is not None, "Could not find preserved tail entry"
|
||||
assert tail_start > 0, "Tail should not be the first entry"
|
||||
|
||||
# The tail entry's parentUuid must be the uuid of the preceding entry
|
||||
prefix_last_uuid = entries[tail_start - 1]["uuid"]
|
||||
tail_first_parent = entries[tail_start]["parentUuid"]
|
||||
assert tail_first_parent == prefix_last_uuid, (
|
||||
f"Tail parentUuid {tail_first_parent!r} != "
|
||||
f"last prefix uuid {prefix_last_uuid!r}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_thinking_blocks_still_works(self, mock_chat_config):
|
||||
"""Compaction should still work normally when there are no thinking
|
||||
blocks in the transcript."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", [{"type": "text", "text": "Hi"}]),
|
||||
("user", "More"),
|
||||
("assistant", [{"type": "text", "text": "Details"}]),
|
||||
]
|
||||
)
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[summary]"},
|
||||
{"role": "assistant", "content": "Summary"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 200,
|
||||
"token_count": 50,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
# Verify last assistant content is preserved even without thinking blocks
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None
|
||||
assert last_content == [{"type": "text", "text": "Details"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transcript_to_messages — thinking block handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTranscriptToMessagesThinking:
|
||||
def test_thinking_blocks_excluded_from_flattened_content(self):
|
||||
"""When _transcript_to_messages flattens content, thinking block
|
||||
text should not leak into the message content string."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "SECRET_THOUGHT",
|
||||
"signature": "sig",
|
||||
},
|
||||
{"type": "text", "text": "Visible response"},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
messages = _transcript_to_messages(transcript)
|
||||
assistant_msg = [m for m in messages if m["role"] == "assistant"][0]
|
||||
assert "SECRET_THOUGHT" not in assistant_msg["content"]
|
||||
assert "Visible response" in assistant_msg["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# response_adapter — ThinkingBlock handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResponseAdapterThinkingBlock:
|
||||
def test_thinking_block_does_not_crash(self):
|
||||
"""ThinkingBlock in AssistantMessage should not cause an error."""
|
||||
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
|
||||
msg = AssistantMessage(
|
||||
content=[
|
||||
ThinkingBlock(
|
||||
thinking="Let me think about this...",
|
||||
signature="sig_test_123",
|
||||
),
|
||||
TextBlock(text="Here is my response."),
|
||||
],
|
||||
model="claude-test",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# Should produce stream events for text only, no crash
|
||||
types = [type(r) for r in results]
|
||||
assert StreamStartStep in types
|
||||
assert StreamTextStart in types or StreamTextDelta in types
|
||||
|
||||
def test_thinking_block_does_not_emit_stream_events(self):
|
||||
"""ThinkingBlock should NOT produce any StreamTextDelta events
|
||||
containing thinking content."""
|
||||
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
|
||||
msg = AssistantMessage(
|
||||
content=[
|
||||
ThinkingBlock(
|
||||
thinking="My secret thoughts",
|
||||
signature="sig_test_456",
|
||||
),
|
||||
TextBlock(text="Public response"),
|
||||
],
|
||||
model="claude-test",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
|
||||
for delta in text_deltas:
|
||||
assert "secret thoughts" not in (delta.delta or "")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_sdk_content_blocks — redacted_thinking handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatSdkContentBlocks:
|
||||
def test_thinking_block_preserved(self):
|
||||
"""ThinkingBlock should be serialized with type, thinking, and signature."""
|
||||
blocks = [
|
||||
ThinkingBlock(thinking="My thoughts", signature="sig123"),
|
||||
TextBlock(text="Response"),
|
||||
]
|
||||
result = _format_sdk_content_blocks(blocks)
|
||||
assert len(result) == 2
|
||||
assert result[0] == {
|
||||
"type": "thinking",
|
||||
"thinking": "My thoughts",
|
||||
"signature": "sig123",
|
||||
}
|
||||
assert result[1] == {"type": "text", "text": "Response"}
|
||||
|
||||
def test_raw_dict_redacted_thinking_preserved(self):
|
||||
"""Raw dict blocks (e.g. redacted_thinking) pass through unchanged."""
|
||||
raw_block = {"type": "redacted_thinking", "data": "EmwKAh...encrypted"}
|
||||
blocks = [
|
||||
raw_block,
|
||||
TextBlock(text="Response"),
|
||||
]
|
||||
result = _format_sdk_content_blocks(blocks)
|
||||
assert len(result) == 2
|
||||
assert result[0] == raw_block
|
||||
assert result[1] == {"type": "text", "text": "Response"}
|
||||
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
from backend.copilot.context import (
|
||||
_current_permissions,
|
||||
_current_project_dir,
|
||||
_current_sandbox,
|
||||
_current_sdk_cwd,
|
||||
@@ -41,6 +42,8 @@ from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
|
||||
@@ -50,6 +53,14 @@ _MCP_MAX_CHARS = 500_000
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
||||
|
||||
# Map from tool_name -> Queue of pre-launched (task, args) pairs.
|
||||
# Initialised per-session in set_execution_context() so concurrent sessions
|
||||
# never share the same dict.
|
||||
_TaskQueueItem = tuple[asyncio.Task[dict[str, Any]], dict[str, Any]]
|
||||
_tool_task_queues: ContextVar[dict[str, asyncio.Queue[_TaskQueueItem]] | None] = (
|
||||
ContextVar("_tool_task_queues", default=None)
|
||||
)
|
||||
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
@@ -66,12 +77,23 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
|
||||
"_stash_event", default=None
|
||||
)
|
||||
|
||||
# Circuit breaker: tracks consecutive tool failures to detect infinite retry loops.
|
||||
# When a tool is called repeatedly with empty/identical args and keeps failing,
|
||||
# this counter is incremented. After _MAX_CONSECUTIVE_TOOL_FAILURES identical
|
||||
# failures the tool handler returns a hard-stop message instead of the raw error.
|
||||
_MAX_CONSECUTIVE_TOOL_FAILURES = 3
|
||||
_consecutive_tool_failures: ContextVar[dict[str, int]] = ContextVar(
|
||||
"_consecutive_tool_failures",
|
||||
default=None, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
@@ -83,14 +105,83 @@ def set_execution_context(
|
||||
session: Current chat session.
|
||||
sandbox: Optional E2B sandbox; when set, bash_exec routes commands there.
|
||||
sdk_cwd: SDK working directory; used to scope tool-results reads.
|
||||
permissions: Optional capability filter restricting tools/blocks.
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
_current_permissions.set(permissions)
|
||||
_pending_tool_outputs.set({})
|
||||
_stash_event.set(asyncio.Event())
|
||||
_tool_task_queues.set({})
|
||||
_consecutive_tool_failures.set({})
|
||||
|
||||
|
||||
def reset_stash_event() -> None:
|
||||
"""Clear any stale stash signal left over from a previous stream attempt.
|
||||
|
||||
``_stash_event`` is set once per session in ``set_execution_context`` and
|
||||
reused across retry attempts. A PostToolUse hook from a failed attempt may
|
||||
leave the event set; calling this at the start of each retry prevents
|
||||
``wait_for_stash`` from returning prematurely on a stale signal.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is not None:
|
||||
event.clear()
|
||||
|
||||
|
||||
async def cancel_pending_tool_tasks() -> None:
|
||||
"""Cancel all queued pre-launched tasks for the current execution context.
|
||||
|
||||
Call this when a stream attempt aborts (error, cancellation) to prevent
|
||||
pre-launched tasks from continuing to execute against a rolled-back session.
|
||||
Tasks that are already done are skipped; in-flight tasks are cancelled and
|
||||
awaited so that any cleanup (``finally`` blocks, DB rollbacks) completes
|
||||
before the next retry starts.
|
||||
"""
|
||||
queues = _tool_task_queues.get()
|
||||
if not queues:
|
||||
return
|
||||
cancelled_tasks: list[asyncio.Task] = []
|
||||
for tool_name, queue in list(queues.items()):
|
||||
cancelled = 0
|
||||
while not queue.empty():
|
||||
task, _args = queue.get_nowait()
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
cancelled_tasks.append(task)
|
||||
cancelled += 1
|
||||
if cancelled:
|
||||
logger.debug(
|
||||
"Cancelled %d pre-launched task(s) for tool '%s'", cancelled, tool_name
|
||||
)
|
||||
queues.clear()
|
||||
# Await all cancelled tasks so their cleanup (finally blocks, DB rollbacks)
|
||||
# completes before the next retry attempt starts new pre-launches.
|
||||
# Use a timeout to prevent hanging indefinitely if a task's cleanup is stuck.
|
||||
if cancelled_tasks:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*cancelled_tasks, return_exceptions=True),
|
||||
timeout=5.0,
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Timed out waiting for %d cancelled task(s) to clean up",
|
||||
len(cancelled_tasks),
|
||||
)
|
||||
|
||||
|
||||
def reset_tool_failure_counters() -> None:
|
||||
"""Reset all tool-level circuit breaker counters.
|
||||
|
||||
Called at the start of each SDK retry attempt so that failure counts
|
||||
from a previous (rolled-back) attempt do not carry over and prematurely
|
||||
trip the breaker on a fresh attempt with different context.
|
||||
"""
|
||||
_consecutive_tool_failures.set({})
|
||||
|
||||
|
||||
def pop_pending_tool_output(tool_name: str) -> str | None:
|
||||
@@ -155,12 +246,13 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
|
||||
by waiting on the ``_stash_event``, which is signaled by
|
||||
:func:`stash_pending_tool_output`.
|
||||
|
||||
Returns ``True`` if a stash signal was received, ``False`` on timeout.
|
||||
Uses ``asyncio.Event.wait()`` so it returns the instant the hook signals —
|
||||
the timeout is purely a safety net for the case where the hook never fires.
|
||||
Returns ``True`` if the stash signal was received, ``False`` on timeout.
|
||||
|
||||
The 2.0 s default was chosen based on production metrics: the original
|
||||
0.5 s caused frequent timeouts under load (parallel tool calls, large
|
||||
outputs). 2.0 s gives a comfortable margin while still failing fast
|
||||
when the hook genuinely will not fire.
|
||||
The 2.0 s default was chosen to accommodate slower tool startup in cloud
|
||||
sandboxes while still failing fast when the hook genuinely will not fire.
|
||||
With the parallel pre-launch path, hooks typically fire well under 1 ms.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is None:
|
||||
@@ -169,7 +261,7 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
|
||||
if event.is_set():
|
||||
event.clear()
|
||||
return True
|
||||
# Slow path: wait for the hook to signal.
|
||||
# Slow path: block until the hook signals or the safety timeout expires.
|
||||
try:
|
||||
async with asyncio.timeout(timeout):
|
||||
await event.wait()
|
||||
@@ -179,6 +271,82 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def pre_launch_tool_call(tool_name: str, args: dict[str, Any]) -> None:
|
||||
"""Pre-launch a tool as a background task so parallel calls run concurrently.
|
||||
|
||||
Called when an AssistantMessage with ToolUseBlocks is received, before the
|
||||
SDK dispatches the MCP tool/call requests. The tool_handler will await the
|
||||
pre-launched task instead of executing fresh.
|
||||
|
||||
The tool_name may include an MCP prefix (e.g. ``mcp__copilot__run_block``);
|
||||
the prefix is stripped automatically before looking up the tool.
|
||||
|
||||
Ordering guarantee: the Claude Agent SDK dispatches MCP ``tools/call`` requests
|
||||
in the same order as the ToolUseBlocks appear in the AssistantMessage.
|
||||
Pre-launched tasks are queued FIFO per tool name, so the N-th handler for a
|
||||
given tool name dequeues the N-th pre-launched task — result and args always
|
||||
correspond when the SDK preserves order (which it does in the current SDK).
|
||||
"""
|
||||
queues = _tool_task_queues.get()
|
||||
if queues is None:
|
||||
return
|
||||
|
||||
# Strip the MCP server prefix (e.g. "mcp__copilot__") to get the bare tool name.
|
||||
# Use removeprefix so tool names that themselves contain "__" are handled correctly.
|
||||
bare_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
|
||||
|
||||
base_tool = TOOL_REGISTRY.get(bare_name)
|
||||
if base_tool is None:
|
||||
return
|
||||
|
||||
user_id, session = get_execution_context()
|
||||
if session is None:
|
||||
return
|
||||
|
||||
# Expand @@agptfile: references before launching the task.
|
||||
# The _truncating wrapper (which normally handles expansion) runs AFTER
|
||||
# pre_launch_tool_call — the pre-launched task would otherwise receive raw
|
||||
# @@agptfile: tokens and fail to resolve them inside _execute_tool_sync.
|
||||
# Use _build_input_schema (same path as _truncating) for schema-aware expansion.
|
||||
input_schema: dict[str, Any] | None
|
||||
try:
|
||||
input_schema = _build_input_schema(base_tool)
|
||||
except Exception:
|
||||
input_schema = None # schema unavailable — skip schema-aware expansion
|
||||
try:
|
||||
args = await expand_file_refs_in_args(
|
||||
args, user_id, session, input_schema=input_schema
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
logger.warning(
|
||||
"pre_launch_tool_call: @@agptfile expansion failed for %s: %s — skipping pre-launch",
|
||||
bare_name,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
task = asyncio.create_task(_execute_tool_sync(base_tool, user_id, session, args))
|
||||
# Log unhandled exceptions so "Task exception was never retrieved" warnings
|
||||
# do not pollute stderr when a task is pre-launched but never dequeued.
|
||||
task.add_done_callback(
|
||||
lambda t, name=bare_name: (
|
||||
logger.warning(
|
||||
"Pre-launched task for %s raised unhandled: %s",
|
||||
name,
|
||||
t.exception(),
|
||||
)
|
||||
if not t.cancelled() and t.exception()
|
||||
else None
|
||||
)
|
||||
)
|
||||
|
||||
if bare_name not in queues:
|
||||
queues[bare_name] = asyncio.Queue[_TaskQueueItem]()
|
||||
# Store (task, args) so the handler can log a warning if the SDK dispatches
|
||||
# calls in a different order than the ToolUseBlocks appeared in the message.
|
||||
queues[bare_name].put_nowait((task, args))
|
||||
|
||||
|
||||
async def _execute_tool_sync(
|
||||
base_tool: BaseTool,
|
||||
user_id: str | None,
|
||||
@@ -187,8 +355,10 @@ async def _execute_tool_sync(
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a tool synchronously and return MCP-formatted response.
|
||||
|
||||
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
|
||||
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
|
||||
Note: ``@@agptfile:`` expansion should be performed by the caller before
|
||||
invoking this function. For the normal (non-parallel) path it is handled
|
||||
by the ``_truncating`` wrapper; for the pre-launched parallel path it is
|
||||
handled in :func:`pre_launch_tool_call` before the task is created.
|
||||
"""
|
||||
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
result = await base_tool.execute(
|
||||
@@ -217,6 +387,66 @@ def _mcp_error(message: str) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _failure_key(tool_name: str, args: dict[str, Any]) -> str:
|
||||
"""Compute a stable fingerprint for (tool_name, args) used by the circuit breaker."""
|
||||
args_key = json.dumps(args, sort_keys=True, default=str)
|
||||
return f"{tool_name}:{args_key}"
|
||||
|
||||
|
||||
def _check_circuit_breaker(tool_name: str, args: dict[str, Any]) -> str | None:
|
||||
"""Check if a tool has hit the consecutive failure limit.
|
||||
|
||||
Tracks failures keyed by (tool_name, args_fingerprint). Returns an error
|
||||
message if the circuit breaker has tripped, or None if the call should proceed.
|
||||
"""
|
||||
tracker = _consecutive_tool_failures.get(None)
|
||||
if tracker is None:
|
||||
return None
|
||||
|
||||
key = _failure_key(tool_name, args)
|
||||
count = tracker.get(key, 0)
|
||||
if count >= _MAX_CONSECUTIVE_TOOL_FAILURES:
|
||||
logger.warning(
|
||||
"Circuit breaker tripped for tool %s after %d consecutive "
|
||||
"identical failures (args=%s)",
|
||||
tool_name,
|
||||
count,
|
||||
key[len(tool_name) + 1 :][:200],
|
||||
)
|
||||
return (
|
||||
f"STOP: Tool '{tool_name}' has failed {count} consecutive times with "
|
||||
f"the same arguments. Do NOT retry this tool call. "
|
||||
f"If you were trying to write content to a file, instead respond with "
|
||||
f"the content directly as a text message to the user."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _record_tool_failure(tool_name: str, args: dict[str, Any]) -> None:
|
||||
"""Record a tool failure for circuit breaker tracking."""
|
||||
tracker = _consecutive_tool_failures.get(None)
|
||||
if tracker is None:
|
||||
return
|
||||
key = _failure_key(tool_name, args)
|
||||
tracker[key] = tracker.get(key, 0) + 1
|
||||
|
||||
|
||||
def _clear_tool_failures(tool_name: str) -> None:
|
||||
"""Clear failure tracking for a tool on success.
|
||||
|
||||
Clears ALL args variants for the tool, not just the successful call's args.
|
||||
This gives the tool a "fresh start" on any success, which is appropriate for
|
||||
the primary use case (detecting infinite loops with identical failing args).
|
||||
"""
|
||||
tracker = _consecutive_tool_failures.get(None)
|
||||
if tracker is None:
|
||||
return
|
||||
# Clear all entries for this tool name
|
||||
keys_to_remove = [k for k in tracker if k.startswith(f"{tool_name}:")]
|
||||
for k in keys_to_remove:
|
||||
del tracker[k]
|
||||
|
||||
|
||||
def create_tool_handler(base_tool: BaseTool):
|
||||
"""Create an async handler function for a BaseTool.
|
||||
|
||||
@@ -225,7 +455,83 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
"""
|
||||
|
||||
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the wrapped tool and return MCP-formatted response."""
|
||||
"""Execute the wrapped tool and return MCP-formatted response.
|
||||
|
||||
If a pre-launched task exists (from parallel tool pre-launch in the
|
||||
message loop), await it instead of executing fresh.
|
||||
"""
|
||||
queues = _tool_task_queues.get()
|
||||
if queues and base_tool.name in queues:
|
||||
queue = queues[base_tool.name]
|
||||
if not queue.empty():
|
||||
task, launch_args = queue.get_nowait()
|
||||
# Sanity-check: warn if the args don't match — this can happen
|
||||
# if the SDK dispatches tool calls in a different order than the
|
||||
# ToolUseBlocks appeared in the AssistantMessage (unlikely but
|
||||
# could occur in future SDK versions or with SDK bugs).
|
||||
# We compare full values (not just keys) so that two run_block
|
||||
# calls with different block_id values are caught even though
|
||||
# both have the same key set.
|
||||
if launch_args != args:
|
||||
logger.warning(
|
||||
"Pre-launched task for %s: arg mismatch "
|
||||
"(launch_keys=%s, call_keys=%s) — cancelling "
|
||||
"pre-launched task and falling back to direct execution",
|
||||
base_tool.name,
|
||||
(
|
||||
sorted(launch_args.keys())
|
||||
if isinstance(launch_args, dict)
|
||||
else type(launch_args).__name__
|
||||
),
|
||||
(
|
||||
sorted(args.keys())
|
||||
if isinstance(args, dict)
|
||||
else type(args).__name__
|
||||
),
|
||||
)
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
# Await cancellation to prevent duplicate concurrent
|
||||
# execution for blocks with side effects.
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
# Fall through to the direct-execution path below.
|
||||
else:
|
||||
# Args match — await the pre-launched task.
|
||||
try:
|
||||
result = await task
|
||||
except asyncio.CancelledError:
|
||||
# Re-raise: CancelledError may be propagating from the
|
||||
# outer streaming loop being cancelled — swallowing it
|
||||
# would mask the cancellation and prevent proper cleanup.
|
||||
logger.warning(
|
||||
"Pre-launched tool %s was cancelled — re-raising",
|
||||
base_tool.name,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Pre-launched tool %s failed: %s",
|
||||
base_tool.name,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return _mcp_error(
|
||||
f"Failed to execute {base_tool.name}. "
|
||||
"Check server logs for details."
|
||||
)
|
||||
|
||||
# Pre-truncate the result so the _truncating wrapper (which
|
||||
# wraps this handler) receives an already-within-budget
|
||||
# value. _truncating handles stashing — we must NOT stash
|
||||
# here or the output will be appended twice to the FIFO
|
||||
# queue and pop_pending_tool_output would return a duplicate
|
||||
# entry on the second call for the same tool.
|
||||
return truncate(result, _MCP_MAX_CHARS)
|
||||
|
||||
# No pre-launched task — execute directly (fallback for non-parallel calls).
|
||||
user_id, session = get_execution_context()
|
||||
|
||||
if session is None:
|
||||
@@ -234,8 +540,12 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||
logger.error(
|
||||
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
|
||||
)
|
||||
return _mcp_error(
|
||||
f"Failed to execute {base_tool.name}. Check server logs for details."
|
||||
)
|
||||
|
||||
return tool_handler
|
||||
|
||||
@@ -358,6 +668,15 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
Applied once to every registered tool."""
|
||||
|
||||
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
|
||||
# Circuit breaker: stop infinite retry loops with identical args.
|
||||
# Use the original (pre-expansion) args for fingerprinting so
|
||||
# check and record always use the same key — @@agptfile:
|
||||
# expansion mutates args, which would cause a key mismatch.
|
||||
original_args = args
|
||||
stop_msg = _check_circuit_breaker(tool_name, original_args)
|
||||
if stop_msg:
|
||||
return _mcp_error(stop_msg)
|
||||
|
||||
user_id, session = get_execution_context()
|
||||
if session is not None:
|
||||
try:
|
||||
@@ -365,6 +684,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
args, user_id, session, input_schema=input_schema
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
_record_tool_failure(tool_name, original_args)
|
||||
return _mcp_error(
|
||||
f"@@agptfile: reference could not be resolved: {exc}. "
|
||||
"Ensure the file exists before referencing it. "
|
||||
@@ -374,6 +694,12 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
result = await fn(args)
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
|
||||
# Track consecutive failures for circuit breaker
|
||||
if truncated.get("isError"):
|
||||
_record_tool_failure(tool_name, original_args)
|
||||
else:
|
||||
_clear_tool_failures(tool_name)
|
||||
|
||||
# Stash the text so the response adapter can forward our
|
||||
# middle-out truncated version to the frontend instead of the
|
||||
# SDK's head-truncated version (for outputs >~100 KB the SDK
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
"""Tests for tool_adapter helpers: truncation, stash, context vars."""
|
||||
"""Tests for tool_adapter helpers: truncation, stash, context vars, parallel pre-launch."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import get_sdk_cwd
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .tool_adapter import (
|
||||
_MCP_MAX_CHARS,
|
||||
_text_from_mcp_result,
|
||||
cancel_pending_tool_tasks,
|
||||
create_tool_handler,
|
||||
pop_pending_tool_output,
|
||||
pre_launch_tool_call,
|
||||
reset_stash_event,
|
||||
set_execution_context,
|
||||
stash_pending_tool_output,
|
||||
wait_for_stash,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -120,6 +130,69 @@ class TestToolOutputStash:
|
||||
assert pop_pending_tool_output("a") == "alpha"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset_stash_event / wait_for_stash
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetStashEvent:
|
||||
"""Tests for reset_stash_event — the stale-signal fix for retry attempts."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_context(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_clears_stale_signal(self):
|
||||
"""After reset, wait_for_stash does NOT return immediately (blocks until timeout)."""
|
||||
# Simulate a stale signal left by a failed attempt's PostToolUse hook.
|
||||
stash_pending_tool_output("some_tool", "stale output")
|
||||
# The stash_pending_tool_output call sets the event.
|
||||
# Now reset it — simulating start of a new retry attempt.
|
||||
reset_stash_event()
|
||||
# wait_for_stash should block and time out since the event was cleared.
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is False, (
|
||||
"wait_for_stash should have timed out after reset_stash_event, "
|
||||
"but it returned True — stale signal was not cleared"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_returns_true_when_signaled_after_reset(self):
|
||||
"""After reset, a new stash signal is correctly detected."""
|
||||
reset_stash_event()
|
||||
|
||||
async def _signal_after_delay():
|
||||
await asyncio.sleep(0.01)
|
||||
stash_pending_tool_output("tool", "fresh output")
|
||||
|
||||
asyncio.create_task(_signal_after_delay())
|
||||
result = await wait_for_stash(timeout=1.0)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_scenario_stale_event_does_not_fire_prematurely(self):
|
||||
"""Simulates: attempt 1 leaves event set → reset → attempt 2 waits correctly."""
|
||||
# Attempt 1: hook fires and sets the event
|
||||
stash_pending_tool_output("t", "attempt-1-output")
|
||||
# Pop it so the stash is empty (simulating normal consumption)
|
||||
pop_pending_tool_output("t")
|
||||
|
||||
# Between attempts: reset (as service.py does before each retry)
|
||||
reset_stash_event()
|
||||
|
||||
# Attempt 2: wait_for_stash should NOT return True immediately
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is False, (
|
||||
"Stale event from attempt 1 caused wait_for_stash to return "
|
||||
"prematurely in attempt 2"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncating wrapper (integration via create_copilot_mcp_server)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -168,3 +241,534 @@ class TestTruncationAndStashIntegration:
|
||||
text = _text_from_mcp_result(truncated)
|
||||
assert len(text) < len(big_text)
|
||||
assert len(str(truncated)) <= _MCP_MAX_CHARS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parallel pre-launch infrastructure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_tool(name: str, output: str = "result") -> MagicMock:
|
||||
"""Return a BaseTool mock that returns a successful StreamToolOutputAvailable."""
|
||||
tool = MagicMock()
|
||||
tool.name = name
|
||||
tool.parameters = {"properties": {}, "required": []}
|
||||
tool.execute = AsyncMock(
|
||||
return_value=StreamToolOutputAvailable(
|
||||
toolCallId="test-id",
|
||||
output=output,
|
||||
toolName=name,
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
def _make_mock_session() -> MagicMock:
|
||||
"""Return a minimal ChatSession mock."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
def _init_ctx(session=None):
|
||||
set_execution_context(
|
||||
user_id="user-1",
|
||||
session=session, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
)
|
||||
|
||||
|
||||
class TestPreLaunchToolCall:
|
||||
"""Tests for pre_launch_tool_call and the queue-based parallel dispatch."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_is_silently_ignored(self):
|
||||
"""pre_launch_tool_call does nothing for tools not in TOOL_REGISTRY."""
|
||||
# Should not raise even if the tool name is completely unknown
|
||||
await pre_launch_tool_call("nonexistent_tool", {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_prefix_stripped_before_registry_lookup(self):
|
||||
"""mcp__copilot__run_block is looked up as 'run_block'."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("mcp__copilot__run_block", {"block_id": "b1"})
|
||||
|
||||
# The task was enqueued — mock_tool.execute should be called once
|
||||
# (may not complete immediately but should start)
|
||||
await asyncio.sleep(0) # yield to event loop
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_tool_name_without_prefix(self):
|
||||
"""Tool names without __ separator are looked up as-is."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
|
||||
await asyncio.sleep(0)
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_enqueued_fifo_for_same_tool(self):
|
||||
"""Two pre-launched calls for the same tool name are enqueued FIFO."""
|
||||
results = []
|
||||
|
||||
async def slow_execute(*args, **kwargs):
|
||||
results.append(len(results))
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="id",
|
||||
output=str(len(results) - 1),
|
||||
toolName="t",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("t")
|
||||
mock_tool.execute = AsyncMock(side_effect=slow_execute)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"t": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("t", {"n": 1})
|
||||
await pre_launch_tool_call("t", {"n": 2})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_ref_expansion_failure_skips_pre_launch(self):
|
||||
"""When @@agptfile: expansion fails, pre_launch_tool_call skips the task.
|
||||
|
||||
The handler should then fall back to direct execution (which will also
|
||||
fail with a proper MCP error via _truncating's own expansion).
|
||||
"""
|
||||
mock_tool = _make_mock_tool("run_block", output="should-not-execute")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.tool_adapter.expand_file_refs_in_args",
|
||||
AsyncMock(side_effect=FileRefExpansionError("@@agptfile:missing.txt")),
|
||||
),
|
||||
):
|
||||
# Should not raise — expansion failure is handled gracefully
|
||||
await pre_launch_tool_call("run_block", {"text": "@@agptfile:missing.txt"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# No task was pre-launched — execute was not called
|
||||
mock_tool.execute.assert_not_awaited()
|
||||
|
||||
|
||||
class TestCreateToolHandlerParallel:
|
||||
"""Tests for create_tool_handler using pre-launched tasks."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_uses_prelaunched_task(self):
|
||||
"""Handler pops and awaits the pre-launched task rather than re-executing."""
|
||||
mock_tool = _make_mock_tool("run_block", output="pre-launched result")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0) # let task start
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
text = result["content"][0]["text"]
|
||||
assert "pre-launched result" in text
|
||||
# Should only have been called once (the pre-launched task), not twice
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_does_not_double_stash_for_prelaunched_task(self):
|
||||
"""Pre-launched task result must NOT be stashed by tool_handler directly.
|
||||
|
||||
The _truncating wrapper wraps tool_handler and handles stashing after
|
||||
tool_handler returns. If tool_handler also stashed, the output would be
|
||||
appended twice to the FIFO queue and pop_pending_tool_output would return
|
||||
a duplicate on the second call.
|
||||
|
||||
This test calls tool_handler directly (without _truncating) and asserts
|
||||
that nothing was stashed — confirming stashing is deferred to _truncating.
|
||||
"""
|
||||
mock_tool = _make_mock_tool("run_block", output="stash-me")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
assert "stash-me" in result["content"][0]["text"]
|
||||
# tool_handler must NOT stash — _truncating (which wraps handler) does it.
|
||||
# Calling pop here (without going through _truncating) should return None.
|
||||
not_stashed = pop_pending_tool_output("run_block")
|
||||
assert not_stashed is None, (
|
||||
"tool_handler must not stash directly — _truncating handles stashing "
|
||||
"to prevent double-stash in the FIFO queue"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_falls_back_when_queue_empty(self):
|
||||
"""When no pre-launched task exists, handler executes directly."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct result")
|
||||
|
||||
# Don't call pre_launch_tool_call — queue is empty
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
text = result["content"][0]["text"]
|
||||
assert "direct result" in text
|
||||
mock_tool.execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_cancelled_error_propagates(self):
|
||||
"""CancelledError from a pre-launched task is re-raised to preserve cancellation semantics."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await handler({"block_id": "b1"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_exception_returns_mcp_error(self):
|
||||
"""Exception from a pre-launched task is caught and returned as MCP error."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=RuntimeError("block exploded"))
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is True
|
||||
assert "Failed to execute run_block" in result["content"][0]["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_same_tool_calls_dispatched_in_order(self):
|
||||
"""Two pre-launched tasks for the same tool are consumed in FIFO order."""
|
||||
call_order = []
|
||||
|
||||
async def execute_with_tag(*args, **kwargs):
|
||||
tag = kwargs.get("block_id", "?")
|
||||
call_order.append(tag)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="id", output=f"out-{tag}", toolName="run_block", success=True
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=execute_with_tag)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "first"})
|
||||
await pre_launch_tool_call("run_block", {"block_id": "second"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
r1 = await handler({"block_id": "first"})
|
||||
r2 = await handler({"block_id": "second"})
|
||||
|
||||
assert "out-first" in r1["content"][0]["text"]
|
||||
assert "out-second" in r2["content"][0]["text"]
|
||||
assert call_order == [
|
||||
"first",
|
||||
"second",
|
||||
], f"Expected FIFO dispatch order but got {call_order}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arg_mismatch_falls_back_to_direct_execution(self):
|
||||
"""When pre-launched args differ from SDK args, handler cancels pre-launched
|
||||
task and falls back to direct execution with the correct args."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct-result")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
# Pre-launch with args {"block_id": "wrong"}
|
||||
await pre_launch_tool_call("run_block", {"block_id": "wrong"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# SDK dispatches with different args
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "correct"})
|
||||
|
||||
assert result["isError"] is False
|
||||
# The tool was called twice: once by pre-launch (wrong args), once by
|
||||
# direct fallback (correct args). The result should come from the
|
||||
# direct execution path.
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_session_falls_back_gracefully(self):
|
||||
"""When session is None and no pre-launched task, handler returns MCP error."""
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
# session=None means get_execution_context returns (user_id, None)
|
||||
set_execution_context(user_id="u", session=None, sandbox=None) # type: ignore[arg-type]
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is True
|
||||
assert "session" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cancel_pending_tool_tasks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCancelPendingToolTasks:
|
||||
"""Tests for cancel_pending_tool_tasks — the stream-abort cleanup helper."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancels_queued_tasks(self):
|
||||
"""Queued tasks are cancelled and the queue is cleared."""
|
||||
ran = False
|
||||
|
||||
async def never_run(*_args, **_kwargs):
|
||||
nonlocal ran
|
||||
await asyncio.sleep(10) # long enough to still be pending
|
||||
ran = True
|
||||
|
||||
mock_tool = _make_mock_tool("run_block")
|
||||
mock_tool.execute = AsyncMock(side_effect=never_run)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0) # let task start
|
||||
await cancel_pending_tool_tasks()
|
||||
await asyncio.sleep(0) # let cancellation propagate
|
||||
|
||||
assert not ran, "Task should have been cancelled before completing"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noop_when_no_tasks_queued(self):
|
||||
"""cancel_pending_tool_tasks does not raise when queues are empty."""
|
||||
await cancel_pending_tool_tasks() # should not raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_does_not_find_cancelled_task(self):
|
||||
"""After cancel, tool_handler falls back to direct execution."""
|
||||
mock_tool = _make_mock_tool("run_block", output="direct-fallback")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"run_block": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("run_block", {"block_id": "b1"})
|
||||
await asyncio.sleep(0)
|
||||
await cancel_pending_tool_tasks()
|
||||
|
||||
# Queue is now empty — handler should execute directly
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({"block_id": "b1"})
|
||||
|
||||
assert result["isError"] is False
|
||||
assert "direct-fallback" in result["content"][0]["text"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concurrent / parallel pre-launch scenarios
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAllParallelToolsPrelaunchedIndependently:
|
||||
"""Simulate SDK sending N separate AssistantMessages for the same tool concurrently."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_parallel_tools_prelaunched_independently(self):
|
||||
"""5 pre-launches for the same tool all enqueue independently and run concurrently.
|
||||
|
||||
Each task sleeps for PER_TASK_S seconds. If they ran sequentially the total
|
||||
wall time would be ~5*PER_TASK_S. Running concurrently it should finish in
|
||||
roughly PER_TASK_S (plus scheduling overhead).
|
||||
"""
|
||||
PER_TASK_S = 0.05
|
||||
N = 5
|
||||
started: list[int] = []
|
||||
finished: list[int] = []
|
||||
|
||||
async def slow_execute(*args, **kwargs):
|
||||
idx = len(started)
|
||||
started.append(idx)
|
||||
await asyncio.sleep(PER_TASK_S)
|
||||
finished.append(idx)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=f"id-{idx}",
|
||||
output=f"result-{idx}",
|
||||
toolName="bash_exec",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("bash_exec")
|
||||
mock_tool.execute = AsyncMock(side_effect=slow_execute)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"bash_exec": mock_tool},
|
||||
):
|
||||
for i in range(N):
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": f"echo {i}"})
|
||||
|
||||
# Measure only the concurrent execution window, not pre-launch overhead.
|
||||
# Starting the timer here avoids false failures on slow CI runners where
|
||||
# the pre_launch_tool_call setup takes longer than the concurrent sleep.
|
||||
t0 = asyncio.get_running_loop().time()
|
||||
await asyncio.sleep(PER_TASK_S * 2)
|
||||
elapsed = asyncio.get_running_loop().time() - t0
|
||||
|
||||
assert mock_tool.execute.await_count == N
|
||||
assert len(finished) == N
|
||||
# Wall time of the sleep window should be well under N * PER_TASK_S
|
||||
# (sequential would be ~0.25s; concurrent finishes in ~PER_TASK_S = 0.05s)
|
||||
assert elapsed < N * PER_TASK_S, (
|
||||
f"Expected concurrent execution (<{N * PER_TASK_S:.2f}s) "
|
||||
f"but sleep window took {elapsed:.2f}s"
|
||||
)
|
||||
|
||||
|
||||
class TestHandlerReturnsResultFromCorrectPrelaunchedTask:
|
||||
"""Pop pre-launched tasks in order and verify each returns its own result."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_returns_result_from_correct_prelaunched_task(self):
|
||||
"""Two pre-launches for the same tool: first handler gets first result, second gets second."""
|
||||
|
||||
async def execute_with_cmd(*args, **kwargs):
|
||||
cmd = kwargs.get("cmd", "?")
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="id",
|
||||
output=f"output-for-{cmd}",
|
||||
toolName="bash_exec",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("bash_exec")
|
||||
mock_tool.execute = AsyncMock(side_effect=execute_with_cmd)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"bash_exec": mock_tool},
|
||||
):
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": "alpha"})
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": "beta"})
|
||||
await asyncio.sleep(0) # let both tasks start
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
r1 = await handler({"cmd": "alpha"})
|
||||
r2 = await handler({"cmd": "beta"})
|
||||
|
||||
text1 = r1["content"][0]["text"]
|
||||
text2 = r2["content"][0]["text"]
|
||||
assert "output-for-alpha" in text1, f"Expected alpha result, got: {text1}"
|
||||
assert "output-for-beta" in text2, f"Expected beta result, got: {text2}"
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
|
||||
class TestFiveConcurrentPrelaunchAllComplete:
|
||||
"""Pre-launch 5 tasks; consume all 5 via handlers; assert all succeed."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_five_concurrent_prelaunch_all_complete(self):
|
||||
"""All 5 pre-launched tasks complete and return successful results."""
|
||||
N = 5
|
||||
call_count = 0
|
||||
|
||||
async def counting_execute(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
n = call_count
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=f"id-{n}",
|
||||
output=f"done-{n}",
|
||||
toolName="bash_exec",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("bash_exec")
|
||||
mock_tool.execute = AsyncMock(side_effect=counting_execute)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.TOOL_REGISTRY",
|
||||
{"bash_exec": mock_tool},
|
||||
):
|
||||
for i in range(N):
|
||||
await pre_launch_tool_call("bash_exec", {"cmd": f"task-{i}"})
|
||||
|
||||
await asyncio.sleep(0) # let all tasks start
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
results = []
|
||||
for i in range(N):
|
||||
results.append(await handler({"cmd": f"task-{i}"}))
|
||||
|
||||
assert (
|
||||
mock_tool.execute.await_count == N
|
||||
), f"Expected {N} execute calls, got {mock_tool.execute.await_count}"
|
||||
for i, result in enumerate(results):
|
||||
assert result["isError"] is False, f"Result {i} should not be an error"
|
||||
text = result["content"][0]["text"]
|
||||
assert "done-" in text, f"Result {i} missing expected output: {text}"
|
||||
|
||||
@@ -605,20 +605,31 @@ COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string.
|
||||
|
||||
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
|
||||
placeholders. This is intentional: ``compress_context`` requires plain
|
||||
text for token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript was
|
||||
already too large for the model — a summarized plain-text version is
|
||||
better than no context at all.
|
||||
placeholders. ``thinking`` and ``redacted_thinking`` blocks are
|
||||
silently dropped — they carry no useful context for compression
|
||||
summaries and must not leak into compacted transcripts (the Anthropic
|
||||
API requires thinking blocks in the last assistant message to be
|
||||
value-identical to the original response; including stale thinking
|
||||
text would violate that constraint).
|
||||
|
||||
This is intentional: ``compress_context`` requires plain text for
|
||||
token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript
|
||||
was already too large for the model.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
btype = block.get("type", "")
|
||||
if btype in _THINKING_BLOCK_TYPES:
|
||||
continue
|
||||
if btype == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif btype == "tool_use":
|
||||
@@ -805,6 +816,68 @@ async def _run_compression(
|
||||
)
|
||||
|
||||
|
||||
def _find_last_assistant_entry(
|
||||
content: str,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Split JSONL lines into (compressible_prefix, preserved_tail).
|
||||
|
||||
The tail starts at the **first** entry of the last assistant turn and
|
||||
includes everything after it (typically trailing user messages). An
|
||||
assistant turn can span multiple consecutive JSONL entries sharing the
|
||||
same ``message.id`` (e.g., a thinking entry followed by a tool_use
|
||||
entry). All entries of the turn are preserved verbatim.
|
||||
|
||||
The Anthropic API requires that ``thinking`` and ``redacted_thinking``
|
||||
blocks in the **last** assistant message remain value-identical to the
|
||||
original response (the API validates parsed signature values, not raw
|
||||
JSON bytes). By excluding the entire turn from compression we
|
||||
guarantee those blocks are never altered.
|
||||
|
||||
Returns ``(all_lines, [])`` when no assistant entry is found.
|
||||
"""
|
||||
lines = [ln for ln in content.strip().split("\n") if ln.strip()]
|
||||
|
||||
# Parse all lines once to avoid double JSON deserialization.
|
||||
# json.loads with fallback=None returns Any; non-dict entries are
|
||||
# safely skipped by the isinstance(entry, dict) guards below.
|
||||
parsed: list = [json.loads(ln, fallback=None) for ln in lines]
|
||||
|
||||
# Reverse scan: find the message.id and index of the last assistant entry.
|
||||
last_asst_msg_id: str | None = None
|
||||
last_asst_idx: int | None = None
|
||||
for i in range(len(parsed) - 1, -1, -1):
|
||||
entry = parsed[i]
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
if msg.get("role") == "assistant":
|
||||
last_asst_idx = i
|
||||
last_asst_msg_id = msg.get("id")
|
||||
break
|
||||
|
||||
if last_asst_idx is None:
|
||||
return lines, []
|
||||
|
||||
# If the assistant entry has no message.id, fall back to preserving
|
||||
# from that single entry onward — safer than compressing everything.
|
||||
if last_asst_msg_id is None:
|
||||
return lines[:last_asst_idx], lines[last_asst_idx:]
|
||||
|
||||
# Forward scan: find the first entry of this turn (same message.id).
|
||||
first_turn_idx: int | None = None
|
||||
for i, entry in enumerate(parsed):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
if msg.get("role") == "assistant" and msg.get("id") == last_asst_msg_id:
|
||||
first_turn_idx = i
|
||||
break
|
||||
|
||||
if first_turn_idx is None:
|
||||
return lines, []
|
||||
return lines[:first_turn_idx], lines[first_turn_idx:]
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
*,
|
||||
@@ -816,42 +889,50 @@ async def compact_transcript(
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
|
||||
is flattened to plain text for compression. This matches the fidelity of
|
||||
the Plan C (DB compression) fallback path, where
|
||||
``_format_conversation_context`` similarly renders tool calls as
|
||||
``You called tool: name(args)`` and results as ``Tool result: ...``.
|
||||
Neither path preserves structured API content blocks — the compacted
|
||||
context serves as text history for the LLM, which creates proper
|
||||
structured tool calls going forward.
|
||||
The **last assistant entry** (and any entries after it) are preserved
|
||||
verbatim — never flattened or compressed. The Anthropic API requires
|
||||
``thinking`` and ``redacted_thinking`` blocks in the latest assistant
|
||||
message to be value-identical to the original response (the API
|
||||
validates parsed signature values, not raw JSON bytes); compressing
|
||||
them would destroy the cryptographic signatures and cause
|
||||
``invalid_request_error``.
|
||||
|
||||
Images are per-turn attachments loaded from workspace storage by file ID
|
||||
(via ``_prepare_file_attachments``), not part of the conversation history.
|
||||
They are re-attached each turn and are unaffected by compaction.
|
||||
Structured content in *older* assistant entries (``tool_use`` blocks,
|
||||
``thinking`` blocks, ``tool_result`` nesting, images) is flattened to
|
||||
plain text for compression. This matches the fidelity of the Plan C
|
||||
(DB compression) fallback path.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
|
||||
See also:
|
||||
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
|
||||
lists for pre-query DB history. Both share ``compress_context()``
|
||||
but operate on different input formats (JSONL transcript entries
|
||||
here vs. ChatMessage dicts there).
|
||||
lists for pre-query DB history.
|
||||
"""
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
prefix_lines, tail_lines = _find_last_assistant_entry(content)
|
||||
|
||||
# Build the JSONL string for the compressible prefix
|
||||
prefix_content = "\n".join(prefix_lines) + "\n" if prefix_lines else ""
|
||||
messages = _transcript_to_messages(prefix_content) if prefix_content else []
|
||||
|
||||
if len(messages) + len(tail_lines) < 2:
|
||||
total = len(messages) + len(tail_lines)
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, total)
|
||||
return None
|
||||
if not messages:
|
||||
logger.warning("%s Nothing to compress (only tail entries remain)", log_prefix)
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, model, log_prefix)
|
||||
if not result.was_compacted:
|
||||
# Compressor says it's within budget, but the SDK rejected it.
|
||||
# Return None so the caller falls through to DB fallback.
|
||||
logger.warning(
|
||||
"%s Compressor reports within budget but SDK rejected — "
|
||||
"signalling failure",
|
||||
log_prefix,
|
||||
)
|
||||
return None
|
||||
if not result.messages:
|
||||
logger.warning("%s Compressor returned empty messages", log_prefix)
|
||||
return None
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
@@ -860,7 +941,29 @@ async def compact_transcript(
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
compressed_part = _messages_to_transcript(result.messages)
|
||||
|
||||
# Re-append the preserved tail (last assistant + trailing entries)
|
||||
# with parentUuid patched to chain onto the compressed prefix.
|
||||
tail_part = _rechain_tail(compressed_part, tail_lines)
|
||||
compacted = compressed_part + tail_part
|
||||
|
||||
if len(compacted) >= len(content):
|
||||
# Byte count can increase due to preserved tail entries
|
||||
# (thinking blocks, JSON overhead) even when token count
|
||||
# decreased. Log a warning but still return — the API
|
||||
# validates tokens not bytes, and the caller falls through
|
||||
# to DB fallback if the transcript is still too large.
|
||||
logger.warning(
|
||||
"%s Compacted transcript (%d bytes) is not smaller than "
|
||||
"original (%d bytes) — may still reduce token count",
|
||||
log_prefix,
|
||||
len(compacted),
|
||||
len(content),
|
||||
)
|
||||
# Authoritative validation — the caller (_reduce_context) also
|
||||
# validates, but this is the canonical check that guarantees we
|
||||
# never return a malformed transcript from this function.
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
@@ -870,3 +973,43 @@ async def compact_transcript(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _rechain_tail(compressed_prefix: str, tail_lines: list[str]) -> str:
|
||||
"""Patch tail entries so their parentUuid chain links to the compressed prefix.
|
||||
|
||||
The first tail entry's ``parentUuid`` is set to the ``uuid`` of the
|
||||
last entry in the compressed prefix. Subsequent tail entries are
|
||||
rechained to point to their predecessor in the tail — their original
|
||||
``parentUuid`` values may reference entries that were compressed away.
|
||||
"""
|
||||
if not tail_lines:
|
||||
return ""
|
||||
# Find the last uuid in the compressed prefix
|
||||
last_prefix_uuid = ""
|
||||
for line in reversed(compressed_prefix.strip().split("\n")):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if isinstance(entry, dict) and "uuid" in entry:
|
||||
last_prefix_uuid = entry["uuid"]
|
||||
break
|
||||
|
||||
result_lines: list[str] = []
|
||||
prev_uuid: str | None = None
|
||||
for i, line in enumerate(tail_lines):
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
# Safety guard: _find_last_assistant_entry already filters empty
|
||||
# lines, and well-formed JSONL always parses to dicts. Non-dict
|
||||
# lines are passed through unchanged; prev_uuid is intentionally
|
||||
# NOT updated so the next dict entry chains to the last known uuid.
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if i == 0:
|
||||
entry["parentUuid"] = last_prefix_uuid
|
||||
elif prev_uuid is not None:
|
||||
entry["parentUuid"] = prev_uuid
|
||||
prev_uuid = entry.get("uuid")
|
||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
return "\n".join(result_lines) + "\n"
|
||||
|
||||
@@ -17,13 +17,16 @@ Subscribers:
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.api.model import CopilotCompletionPayload
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.notification_bus import (
|
||||
AsyncRedisNotificationEventBus,
|
||||
NotificationEvent,
|
||||
@@ -33,12 +36,21 @@ from backend.data.redis_client import get_redis_async
|
||||
from .config import ChatConfig
|
||||
from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS
|
||||
from .response_model import (
|
||||
ResponseType,
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -100,6 +112,14 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
|
||||
``session_id`` is used as a fallback for ``turn_id`` when the meta hash
|
||||
pre-dates the turn_id field (backward compat for in-flight sessions).
|
||||
"""
|
||||
created_at = datetime.now(timezone.utc)
|
||||
created_at_raw = meta.get("created_at")
|
||||
if created_at_raw:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(str(created_at_raw))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return ActiveSession(
|
||||
session_id=meta.get("session_id", "") or session_id,
|
||||
user_id=meta.get("user_id", "") or None,
|
||||
@@ -108,6 +128,7 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
|
||||
turn_id=meta.get("turn_id", "") or session_id,
|
||||
blocking=meta.get("blocking") == "1",
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
|
||||
@@ -280,6 +301,56 @@ async def publish_chunk(
|
||||
return message_id
|
||||
|
||||
|
||||
async def stream_and_publish(
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
stream: AsyncIterator[StreamBaseResponse],
|
||||
) -> AsyncIterator[StreamBaseResponse]:
|
||||
"""Wrap an async stream iterator with registry publishing.
|
||||
|
||||
Publishes each chunk to the stream registry for frontend SSE consumption,
|
||||
skipping ``StreamFinish`` and ``StreamError`` (which are published by
|
||||
:func:`mark_session_completed`).
|
||||
|
||||
This is a pass-through: every event from *stream* is yielded unchanged so
|
||||
the caller can still consume and aggregate them. The caller is responsible
|
||||
for calling :func:`create_session` before and :func:`mark_session_completed`
|
||||
after iterating.
|
||||
|
||||
Args:
|
||||
session_id: Chat session ID (for logging only).
|
||||
turn_id: Turn UUID that identifies the Redis stream to publish to.
|
||||
If empty, publishing is silently skipped (graceful degradation).
|
||||
stream: The underlying async iterator of stream events.
|
||||
|
||||
Yields:
|
||||
Every event from *stream*, unchanged.
|
||||
"""
|
||||
publish_failed_once = False
|
||||
|
||||
async for event in stream:
|
||||
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
|
||||
try:
|
||||
await publish_chunk(turn_id, event)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
if not publish_failed_once:
|
||||
publish_failed_once = True
|
||||
logger.warning(
|
||||
"[stream_and_publish] Failed to publish chunk %s for %s "
|
||||
"(further failures logged at DEBUG)",
|
||||
type(event).__name__,
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"[stream_and_publish] Failed to publish chunk %s",
|
||||
type(event).__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
yield event
|
||||
|
||||
|
||||
async def subscribe_to_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
@@ -693,6 +764,8 @@ async def _stream_listener(
|
||||
async def mark_session_completed(
|
||||
session_id: str,
|
||||
error_message: str | None = None,
|
||||
*,
|
||||
skip_error_publish: bool = False,
|
||||
) -> bool:
|
||||
"""Mark a session as completed, then publish StreamFinish.
|
||||
|
||||
@@ -708,6 +781,10 @@ async def mark_session_completed(
|
||||
session_id: Session ID to mark as completed
|
||||
error_message: If provided, marks as "failed" and publishes a
|
||||
StreamError before StreamFinish. Otherwise marks as "completed".
|
||||
skip_error_publish: If True, still marks the session as "failed" but
|
||||
does NOT publish a StreamError event. Use this when the error has
|
||||
already been published to the stream (e.g. via stream_and_publish)
|
||||
to avoid duplicate error delivery to the frontend.
|
||||
|
||||
Returns:
|
||||
True if session was newly marked completed, False if already completed/failed
|
||||
@@ -727,7 +804,7 @@ async def mark_session_completed(
|
||||
logger.debug(f"Session {session_id} already completed/failed, skipping")
|
||||
return False
|
||||
|
||||
if error_message:
|
||||
if error_message and not skip_error_publish:
|
||||
try:
|
||||
await publish_chunk(turn_id, StreamError(errorText=error_message))
|
||||
except Exception as e:
|
||||
@@ -735,6 +812,33 @@ async def mark_session_completed(
|
||||
f"Failed to publish error event for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
# Compute wall-clock duration from session created_at.
|
||||
# Only persist when (a) the session completed successfully and
|
||||
# (b) created_at was actually present in Redis meta (not a fallback).
|
||||
duration_ms: int | None = None
|
||||
if meta and not error_message:
|
||||
created_at_raw = meta.get("created_at")
|
||||
if created_at_raw:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(str(created_at_raw))
|
||||
if created_at.tzinfo is None:
|
||||
created_at = created_at.replace(tzinfo=timezone.utc)
|
||||
elapsed = datetime.now(timezone.utc) - created_at
|
||||
duration_ms = max(0, int(elapsed.total_seconds() * 1000))
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Failed to compute session duration for %s (created_at=%r)",
|
||||
session_id,
|
||||
created_at_raw,
|
||||
)
|
||||
|
||||
# Persist duration on the last assistant message
|
||||
if duration_ms is not None:
|
||||
try:
|
||||
await chat_db().set_turn_duration(session_id, duration_ms)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save turn duration for {session_id}: {e}")
|
||||
|
||||
# Publish StreamFinish AFTER status is set to "completed"/"failed".
|
||||
# This is the SINGLE place that publishes StreamFinish — services and
|
||||
# the processor must NOT publish it themselves.
|
||||
@@ -913,21 +1017,6 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
Returns:
|
||||
Reconstructed response object, or None if unknown type
|
||||
"""
|
||||
from .response_model import (
|
||||
ResponseType,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextEnd,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
# Map response types to their corresponding classes
|
||||
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
||||
ResponseType.START.value: StreamStart,
|
||||
|
||||
@@ -102,7 +102,6 @@ async def setup_test_data(server):
|
||||
"value": "",
|
||||
"advanced": False,
|
||||
"description": "Test input field",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
metadata={"position": {"x": 0, "y": 0}},
|
||||
)
|
||||
@@ -242,7 +241,6 @@ async def setup_llm_test_data(server):
|
||||
"value": "",
|
||||
"advanced": False,
|
||||
"description": "Prompt for the LLM",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
metadata={"position": {"x": 0, "y": 0}},
|
||||
)
|
||||
@@ -396,7 +394,6 @@ async def setup_firecrawl_test_data(server):
|
||||
"value": "",
|
||||
"advanced": False,
|
||||
"description": "URL for Firecrawl to scrape",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
metadata={"position": {"x": 0, "y": 0}},
|
||||
)
|
||||
|
||||
@@ -20,9 +20,9 @@ SSRF protection:
|
||||
|
||||
Requires:
|
||||
npm install -g agent-browser
|
||||
agent-browser install (downloads Chromium, one-time — skipped in Docker
|
||||
where system chromium is pre-installed and
|
||||
AGENT_BROWSER_EXECUTABLE_PATH is set)
|
||||
In Docker: system chromium package with AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium
|
||||
(set automatically — no `agent-browser install` needed).
|
||||
Locally: run `agent-browser install` to download Chromium.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
@@ -0,0 +1,351 @@
|
||||
"""Integration tests for agent-browser + system chromium.
|
||||
|
||||
These tests actually invoke the agent-browser binary via subprocess and require:
|
||||
- agent-browser installed (npm install -g agent-browser)
|
||||
- AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium (set in Docker)
|
||||
|
||||
Run with:
|
||||
poetry run test
|
||||
|
||||
Or to run only this file:
|
||||
poetry run pytest backend/copilot/tools/agent_browser_integration_test.py -v -p no:autogpt_platform
|
||||
|
||||
Skipped automatically when agent-browser binary is not found.
|
||||
Tests that hit external sites are marked ``integration`` and skipped by default
|
||||
in CI (use ``-m integration`` to include them).
|
||||
|
||||
Two test tiers:
|
||||
- CLI tests: call agent-browser subprocess directly (no backend imports needed)
|
||||
- Tool class tests: call BrowserNavigateTool/BrowserActTool._execute() directly
|
||||
with user_id=None (skips workspace/DB interactions — no Postgres/RabbitMQ needed)
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.agent_browser import BrowserActTool, BrowserNavigateTool
|
||||
from backend.copilot.tools.models import (
|
||||
BrowserActResponse,
|
||||
BrowserNavigateResponse,
|
||||
ErrorResponse,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
shutil.which("agent-browser") is None,
|
||||
reason="agent-browser binary not found",
|
||||
)
|
||||
|
||||
_SESSION = "integration-test-session"
|
||||
|
||||
|
||||
def _agent_browser(
|
||||
*args: str, session: str = _SESSION, timeout: int = 30
|
||||
) -> tuple[int, str, str]:
|
||||
"""Run agent-browser for the given session, return (rc, stdout, stderr)."""
|
||||
result = subprocess.run(
|
||||
["agent-browser", "--session", session, "--session-name", session, *args],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
return result.returncode, result.stdout, result.stderr
|
||||
|
||||
|
||||
def _close_session(session: str, timeout: int = 5) -> None:
|
||||
"""Best-effort close for a browser session; never raises on failure."""
|
||||
try:
|
||||
subprocess.run(
|
||||
["agent-browser", "--session", session, "--session-name", session, "close"],
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _teardown():
|
||||
"""Close the shared test session after each test (best-effort)."""
|
||||
yield
|
||||
_close_session(_SESSION)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_chromium_executable_env_is_set():
|
||||
"""AGENT_BROWSER_EXECUTABLE_PATH must be set and point to an executable binary."""
|
||||
exe = os.environ.get("AGENT_BROWSER_EXECUTABLE_PATH", "")
|
||||
assert exe, "AGENT_BROWSER_EXECUTABLE_PATH is not set"
|
||||
assert os.path.isfile(exe), f"Chromium binary not found at {exe}"
|
||||
assert os.access(exe, os.X_OK), f"Chromium binary at {exe} is not executable"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_navigate_returns_success():
|
||||
"""agent-browser can open a public URL using system chromium."""
|
||||
rc, _, stderr = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0, f"open failed (rc={rc}): {stderr}"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_title_after_navigate():
|
||||
"""get title returns the page title after navigation."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
rc, stdout, stderr = _agent_browser("get", "title", timeout=10)
|
||||
assert rc == 0, f"get title failed: {stderr}"
|
||||
assert "example" in stdout.lower()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_url_after_navigate():
|
||||
"""get url returns the navigated URL."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
rc, stdout, stderr = _agent_browser("get", "url", timeout=10)
|
||||
assert rc == 0, f"get url failed: {stderr}"
|
||||
assert urlparse(stdout.strip()).netloc == "example.com"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_snapshot_returns_interactive_elements():
|
||||
"""snapshot -i -c lists interactive elements on the page."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
rc, stdout, stderr = _agent_browser("snapshot", "-i", "-c", timeout=15)
|
||||
assert rc == 0, f"snapshot failed: {stderr}"
|
||||
assert len(stdout.strip()) > 0, "snapshot returned empty output"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_screenshot_produces_valid_png():
|
||||
"""screenshot saves a non-empty, valid PNG file."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
tmp = f.name
|
||||
try:
|
||||
rc, _, stderr = _agent_browser("screenshot", tmp, timeout=15)
|
||||
assert rc == 0, f"screenshot failed: {stderr}"
|
||||
size = os.path.getsize(tmp)
|
||||
assert size > 1000, f"PNG too small ({size} bytes) — likely blank or corrupt"
|
||||
with open(tmp, "rb") as f:
|
||||
assert f.read(4) == b"\x89PNG", "Output is not a valid PNG"
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_scroll_down():
|
||||
"""scroll down succeeds without error."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
rc, _, stderr = _agent_browser("scroll", "down", timeout=10)
|
||||
assert rc == 0, f"scroll failed: {stderr}"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_fill_form_field():
|
||||
"""fill writes text into an input field."""
|
||||
rc, _, _ = _agent_browser("open", "https://httpbin.org/forms/post")
|
||||
assert rc == 0
|
||||
|
||||
rc, _, stderr = _agent_browser(
|
||||
"fill", "input[name=custname]", "IntegrationTestUser", timeout=10
|
||||
)
|
||||
assert rc == 0, f"fill failed: {stderr}"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_concurrent_independent_sessions():
|
||||
"""Two independent sessions can navigate in parallel without interference."""
|
||||
session_a = "integration-concurrent-a"
|
||||
session_b = "integration-concurrent-b"
|
||||
|
||||
try:
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
|
||||
fut_a = pool.submit(
|
||||
_agent_browser, "open", "https://example.com", session=session_a
|
||||
)
|
||||
fut_b = pool.submit(
|
||||
_agent_browser, "open", "https://httpbin.org/html", session=session_b
|
||||
)
|
||||
rc_a, _, err_a = fut_a.result(timeout=40)
|
||||
rc_b, _, err_b = fut_b.result(timeout=40)
|
||||
assert rc_a == 0, f"session_a open failed: {err_a}"
|
||||
assert rc_b == 0, f"session_b open failed: {err_b}"
|
||||
|
||||
rc_ua, url_a, err_ua = _agent_browser(
|
||||
"get", "url", session=session_a, timeout=10
|
||||
)
|
||||
rc_ub, url_b, err_ub = _agent_browser(
|
||||
"get", "url", session=session_b, timeout=10
|
||||
)
|
||||
assert rc_ua == 0, f"session_a get url failed: {err_ua}"
|
||||
assert rc_ub == 0, f"session_b get url failed: {err_ub}"
|
||||
assert urlparse(url_a.strip()).netloc == "example.com"
|
||||
assert urlparse(url_b.strip()).netloc == "httpbin.org"
|
||||
finally:
|
||||
_close_session(session_a)
|
||||
_close_session(session_b)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_close_session():
|
||||
"""close shuts down the browser daemon cleanly."""
|
||||
rc, _, _ = _agent_browser("open", "https://example.com")
|
||||
assert rc == 0
|
||||
|
||||
rc, _, stderr = _agent_browser("close", timeout=10)
|
||||
assert rc == 0, f"close failed: {stderr}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Python tool class integration tests
|
||||
#
|
||||
# These tests exercise the actual BrowserNavigateTool / BrowserActTool Python
|
||||
# classes (not just the CLI binary) to verify the full call path — URL
|
||||
# validation, subprocess dispatch, response parsing — works with system
|
||||
# chromium. user_id=None skips workspace/DB interactions so no Postgres or
|
||||
# RabbitMQ is needed.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_TOOL_SESSION_ID = "integration-tool-test-session"
|
||||
_TEST_SESSION = ChatSession(
|
||||
session_id=_TOOL_SESSION_ID,
|
||||
user_id="test-user",
|
||||
messages=[],
|
||||
usage=[],
|
||||
started_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=False)
|
||||
def _close_tool_session():
|
||||
"""Tear down the tool-test browser session after each tool test."""
|
||||
yield
|
||||
_close_session(_TOOL_SESSION_ID)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_navigate_returns_response(_close_tool_session):
|
||||
"""BrowserNavigateTool._execute returns a BrowserNavigateResponse with real content."""
|
||||
tool = BrowserNavigateTool()
|
||||
resp = await tool._execute(
|
||||
user_id=None, session=_TEST_SESSION, url="https://example.com"
|
||||
)
|
||||
assert isinstance(
|
||||
resp, BrowserNavigateResponse
|
||||
), f"Expected BrowserNavigateResponse, got: {resp}"
|
||||
assert urlparse(resp.url).netloc == "example.com"
|
||||
assert resp.title, "Expected non-empty page title"
|
||||
assert resp.snapshot, "Expected non-empty accessibility snapshot"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"ssrf_url",
|
||||
[
|
||||
"http://169.254.169.254/", # AWS/GCP/Azure metadata endpoint
|
||||
"http://127.0.0.1/", # IPv4 loopback
|
||||
"http://10.0.0.1/", # RFC-1918 private range
|
||||
"http://[::1]/", # IPv6 loopback
|
||||
"http://0.0.0.0/", # Wildcard / INADDR_ANY
|
||||
],
|
||||
)
|
||||
async def test_tool_navigate_blocked_url(ssrf_url: str, _close_tool_session):
|
||||
"""BrowserNavigateTool._execute rejects internal/private URLs (SSRF guard)."""
|
||||
tool = BrowserNavigateTool()
|
||||
resp = await tool._execute(user_id=None, session=_TEST_SESSION, url=ssrf_url)
|
||||
assert isinstance(
|
||||
resp, ErrorResponse
|
||||
), f"Expected ErrorResponse for SSRF URL {ssrf_url!r}, got: {resp}"
|
||||
assert resp.error == "blocked_url"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_navigate_missing_url(_close_tool_session):
|
||||
"""BrowserNavigateTool._execute returns an error when url is empty."""
|
||||
tool = BrowserNavigateTool()
|
||||
resp = await tool._execute(user_id=None, session=_TEST_SESSION, url="")
|
||||
assert isinstance(resp, ErrorResponse)
|
||||
assert resp.error == "missing_url"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_act_scroll(_close_tool_session):
|
||||
"""BrowserActTool._execute can scroll after a navigate."""
|
||||
nav = BrowserNavigateTool()
|
||||
nav_resp = await nav._execute(
|
||||
user_id=None, session=_TEST_SESSION, url="https://example.com"
|
||||
)
|
||||
assert isinstance(nav_resp, BrowserNavigateResponse)
|
||||
|
||||
act = BrowserActTool()
|
||||
resp = await act._execute(
|
||||
user_id=None, session=_TEST_SESSION, action="scroll", direction="down"
|
||||
)
|
||||
assert isinstance(
|
||||
resp, BrowserActResponse
|
||||
), f"Expected BrowserActResponse, got: {resp}"
|
||||
assert resp.action == "scroll"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_act_fill_and_click(_close_tool_session):
|
||||
"""BrowserActTool._execute can fill a form field."""
|
||||
nav = BrowserNavigateTool()
|
||||
nav_resp = await nav._execute(
|
||||
user_id=None, session=_TEST_SESSION, url="https://httpbin.org/forms/post"
|
||||
)
|
||||
assert isinstance(nav_resp, BrowserNavigateResponse)
|
||||
|
||||
act = BrowserActTool()
|
||||
resp = await act._execute(
|
||||
user_id=None,
|
||||
session=_TEST_SESSION,
|
||||
action="fill",
|
||||
target="input[name=custname]",
|
||||
value="ToolIntegrationTest",
|
||||
)
|
||||
assert isinstance(resp, BrowserActResponse), f"fill failed: {resp}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_act_missing_action(_close_tool_session):
|
||||
"""BrowserActTool._execute returns an error when action is missing."""
|
||||
act = BrowserActTool()
|
||||
resp = await act._execute(user_id=None, session=_TEST_SESSION, action="")
|
||||
assert isinstance(resp, ErrorResponse)
|
||||
assert resp.error == "missing_action"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_act_missing_target(_close_tool_session):
|
||||
"""BrowserActTool._execute returns an error when click target is missing."""
|
||||
act = BrowserActTool()
|
||||
resp = await act._execute(
|
||||
user_id=None, session=_TEST_SESSION, action="click", target=""
|
||||
)
|
||||
assert isinstance(resp, ErrorResponse)
|
||||
assert resp.error == "missing_target"
|
||||
@@ -4,10 +4,12 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.data.dynamic_fields import DICT_SPLIT
|
||||
|
||||
from .helpers import (
|
||||
AGENT_EXECUTOR_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
SMART_DECISION_MAKER_BLOCK_ID,
|
||||
TOOL_ORCHESTRATOR_BLOCK_ID,
|
||||
AgentDict,
|
||||
are_types_compatible,
|
||||
generate_uuid,
|
||||
@@ -31,7 +33,7 @@ _GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
||||
_GMAIL_SEND_BLOCK_ID = "6c27abc2-e51d-499e-a85f-5a0041ba94f0"
|
||||
_TEXT_REPLACE_BLOCK_ID = "7e7c87ab-3469-4bcc-9abe-67705091b713"
|
||||
|
||||
# Defaults applied to SmartDecisionMakerBlock nodes by the fixer.
|
||||
# Defaults applied to OrchestratorBlock nodes by the fixer.
|
||||
_SDM_DEFAULTS: dict[str, int | bool] = {
|
||||
"agent_mode_max_iterations": 10,
|
||||
"conversation_compaction": True,
|
||||
@@ -1536,8 +1538,8 @@ class AgentFixer:
|
||||
for link in links:
|
||||
sink_name = link.get("sink_name", "")
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
if DICT_SPLIT in sink_name:
|
||||
parent, child = sink_name.split(DICT_SPLIT, 1)
|
||||
|
||||
# Check if child is a numeric index (invalid for _#_ notation)
|
||||
if child.isdigit():
|
||||
@@ -1639,8 +1641,8 @@ class AgentFixer:
|
||||
|
||||
return agent
|
||||
|
||||
def fix_smart_decision_maker_blocks(self, agent: AgentDict) -> AgentDict:
|
||||
"""Fix SmartDecisionMakerBlock nodes to ensure agent-mode defaults.
|
||||
def fix_orchestrator_blocks(self, agent: AgentDict) -> AgentDict:
|
||||
"""Fix OrchestratorBlock nodes to ensure agent-mode defaults.
|
||||
|
||||
Ensures:
|
||||
1. ``agent_mode_max_iterations`` defaults to ``10`` (bounded agent mode)
|
||||
@@ -1657,7 +1659,7 @@ class AgentFixer:
|
||||
nodes = agent.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != SMART_DECISION_MAKER_BLOCK_ID:
|
||||
if node.get("block_id") != TOOL_ORCHESTRATOR_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id", "unknown")
|
||||
@@ -1670,7 +1672,7 @@ class AgentFixer:
|
||||
if field not in input_default or input_default[field] is None:
|
||||
input_default[field] = default_value
|
||||
self.add_fix_log(
|
||||
f"SmartDecisionMakerBlock {node_id}: "
|
||||
f"OrchestratorBlock {node_id}: "
|
||||
f"Set {field}={default_value!r}"
|
||||
)
|
||||
|
||||
@@ -1763,8 +1765,8 @@ class AgentFixer:
|
||||
# Apply fixes for MCPToolBlock nodes
|
||||
agent = self.fix_mcp_tool_blocks(agent)
|
||||
|
||||
# Apply fixes for SmartDecisionMakerBlock nodes (agent-mode defaults)
|
||||
agent = self.fix_smart_decision_maker_blocks(agent)
|
||||
# Apply fixes for OrchestratorBlock nodes (agent-mode defaults)
|
||||
agent = self.fix_orchestrator_blocks(agent)
|
||||
|
||||
# Apply fixes for AgentExecutorBlock nodes (sub-agents)
|
||||
if library_agents:
|
||||
|
||||
@@ -4,6 +4,8 @@ import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from backend.data.dynamic_fields import DICT_SPLIT
|
||||
|
||||
from .blocks import get_blocks_as_dicts
|
||||
|
||||
__all__ = [
|
||||
@@ -12,7 +14,7 @@ __all__ = [
|
||||
"AGENT_OUTPUT_BLOCK_ID",
|
||||
"AgentDict",
|
||||
"MCP_TOOL_BLOCK_ID",
|
||||
"SMART_DECISION_MAKER_BLOCK_ID",
|
||||
"TOOL_ORCHESTRATOR_BLOCK_ID",
|
||||
"UUID_REGEX",
|
||||
"are_types_compatible",
|
||||
"generate_uuid",
|
||||
@@ -34,7 +36,7 @@ UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")
|
||||
|
||||
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
||||
SMART_DECISION_MAKER_BLOCK_ID = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
TOOL_ORCHESTRATOR_BLOCK_ID = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||
|
||||
@@ -51,8 +53,8 @@ def generate_uuid() -> str:
|
||||
|
||||
def get_defined_property_type(schema: dict[str, Any], name: str) -> str | None:
|
||||
"""Get property type from a schema, handling nested `_#_` notation."""
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
if DICT_SPLIT in name:
|
||||
parent, child = name.split(DICT_SPLIT, 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema and isinstance(
|
||||
parent_schema["properties"], dict
|
||||
|
||||
@@ -5,12 +5,14 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.data.dynamic_fields import DICT_SPLIT
|
||||
|
||||
from .helpers import (
|
||||
AGENT_EXECUTOR_BLOCK_ID,
|
||||
AGENT_INPUT_BLOCK_ID,
|
||||
AGENT_OUTPUT_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
SMART_DECISION_MAKER_BLOCK_ID,
|
||||
TOOL_ORCHESTRATOR_BLOCK_ID,
|
||||
AgentDict,
|
||||
are_types_compatible,
|
||||
get_defined_property_type,
|
||||
@@ -256,95 +258,6 @@ class AgentValidator:
|
||||
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate nested sink links (links with _#_ notation).
|
||||
Returns True if all nested links are valid, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
block_input_schemas = {
|
||||
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
|
||||
for block in blocks
|
||||
}
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
if not sink_name or not sink_id:
|
||||
continue
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
if not sink_node:
|
||||
continue
|
||||
|
||||
block_id = sink_node.get("block_id")
|
||||
input_props = block_input_schemas.get(block_id, {})
|
||||
|
||||
parent_schema = input_props.get(parent)
|
||||
if not parent_schema:
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
self.add_error(
|
||||
f"Invalid nested sink link '{sink_name}' for "
|
||||
f"node '{sink_id}' (block "
|
||||
f"'{block_name}' - {block_id}): Parent property "
|
||||
f"'{parent}' does not exist in the block's "
|
||||
f"input schema."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
# Check if additionalProperties is allowed either directly
|
||||
# or via anyOf
|
||||
allows_additional_properties = parent_schema.get(
|
||||
"additionalProperties", False
|
||||
)
|
||||
|
||||
# Check anyOf for additionalProperties
|
||||
if not allows_additional_properties and "anyOf" in parent_schema:
|
||||
any_of_schemas = parent_schema.get("anyOf", [])
|
||||
if isinstance(any_of_schemas, list):
|
||||
for schema_option in any_of_schemas:
|
||||
if isinstance(schema_option, dict) and schema_option.get(
|
||||
"additionalProperties"
|
||||
):
|
||||
allows_additional_properties = True
|
||||
break
|
||||
|
||||
if not allows_additional_properties:
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and isinstance(parent_schema["properties"], dict)
|
||||
and child in parent_schema["properties"]
|
||||
):
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
self.add_error(
|
||||
f"Invalid nested sink link '{sink_name}' "
|
||||
f"for node '{link.get('sink_id', '')}' (block "
|
||||
f"'{block_name}' - {block_id}): Child "
|
||||
f"property '{child}' does not exist in "
|
||||
f"parent '{parent}' schema. Available "
|
||||
f"properties: "
|
||||
f"{list(parent_schema.get('properties', {}).keys())}"
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_prompt_double_curly_braces_spaces(self, agent: AgentDict) -> bool:
|
||||
"""
|
||||
Validate that prompt parameters do not contain spaces in double curly
|
||||
@@ -471,8 +384,8 @@ class AgentValidator:
|
||||
output_props = block_output_schemas.get(block_id, {})
|
||||
|
||||
# Handle nested source names (with _#_ notation)
|
||||
if "_#_" in source_name:
|
||||
parent, child = source_name.split("_#_", 1)
|
||||
if DICT_SPLIT in source_name:
|
||||
parent, child = source_name.split(DICT_SPLIT, 1)
|
||||
|
||||
parent_schema = output_props.get(parent)
|
||||
if not parent_schema:
|
||||
@@ -553,6 +466,195 @@ class AgentValidator:
|
||||
|
||||
return valid
|
||||
|
||||
def validate_sink_input_existence(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that all sink_names in links and input_default keys in nodes
|
||||
exist in the corresponding block's input schema.
|
||||
|
||||
Checks that for each link the sink_name references a valid input
|
||||
property in the sink block's inputSchema, and that every key in a
|
||||
node's input_default is a recognised input property. Also handles
|
||||
nested inputs with _#_ notation and dynamic schemas for
|
||||
AgentExecutorBlock.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
blocks: List of available blocks with their schemas
|
||||
node_lookup: Optional pre-built node-id → node dict
|
||||
|
||||
Returns:
|
||||
True if all sink input fields exist, False otherwise
|
||||
"""
|
||||
valid = True
|
||||
|
||||
block_input_schemas = {
|
||||
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
|
||||
for block in blocks
|
||||
}
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
|
||||
def get_input_props(node: dict[str, Any]) -> dict[str, Any]:
|
||||
block_id = node.get("block_id", "")
|
||||
if block_id == AGENT_EXECUTOR_BLOCK_ID:
|
||||
input_default = node.get("input_default", {})
|
||||
dynamic_input_schema = input_default.get("input_schema", {})
|
||||
if not isinstance(dynamic_input_schema, dict):
|
||||
dynamic_input_schema = {}
|
||||
dynamic_props = dynamic_input_schema.get("properties", {})
|
||||
if not isinstance(dynamic_props, dict):
|
||||
dynamic_props = {}
|
||||
static_props = block_input_schemas.get(block_id, {})
|
||||
return {**static_props, **dynamic_props}
|
||||
return block_input_schemas.get(block_id, {})
|
||||
|
||||
def check_nested_input(
|
||||
input_props: dict[str, Any],
|
||||
field_name: str,
|
||||
context: str,
|
||||
block_name: str,
|
||||
block_id: str,
|
||||
) -> bool:
|
||||
parent, child = field_name.split(DICT_SPLIT, 1)
|
||||
parent_schema = input_props.get(parent)
|
||||
if not parent_schema:
|
||||
self.add_error(
|
||||
f"{context}: Parent property '{parent}' does not "
|
||||
f"exist in block '{block_name}' ({block_id}) input "
|
||||
f"schema."
|
||||
)
|
||||
return False
|
||||
|
||||
allows_additional = parent_schema.get("additionalProperties", False)
|
||||
# Only anyOf is checked here because Pydantic's JSON schema
|
||||
# emits optional/union fields via anyOf. allOf and oneOf are
|
||||
# not currently used by any block's dict-typed inputs, so
|
||||
# false positives from them are not a concern in practice.
|
||||
if not allows_additional and "anyOf" in parent_schema:
|
||||
for schema_option in parent_schema.get("anyOf", []):
|
||||
if not isinstance(schema_option, dict):
|
||||
continue
|
||||
if schema_option.get("additionalProperties"):
|
||||
allows_additional = True
|
||||
break
|
||||
items_schema = schema_option.get("items")
|
||||
if isinstance(items_schema, dict) and items_schema.get(
|
||||
"additionalProperties"
|
||||
):
|
||||
allows_additional = True
|
||||
break
|
||||
|
||||
if not allows_additional:
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and isinstance(parent_schema["properties"], dict)
|
||||
and child in parent_schema["properties"]
|
||||
):
|
||||
available = (
|
||||
list(parent_schema.get("properties", {}).keys())
|
||||
if isinstance(parent_schema, dict)
|
||||
else []
|
||||
)
|
||||
self.add_error(
|
||||
f"{context}: Child property '{child}' does not "
|
||||
f"exist in parent '{parent}' of block "
|
||||
f"'{block_name}' ({block_id}) input schema. "
|
||||
f"Available properties: {available}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_id = link.get("sink_id")
|
||||
sink_name = link.get("sink_name", "")
|
||||
link_id = link.get("id", "Unknown")
|
||||
|
||||
if not sink_name:
|
||||
# Missing sink_name is caught by validate_data_type_compatibility
|
||||
continue
|
||||
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
if not sink_node:
|
||||
# Already caught by validate_link_node_references
|
||||
continue
|
||||
|
||||
block_id = sink_node.get("block_id", "")
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
input_props = get_input_props(sink_node)
|
||||
|
||||
context = (
|
||||
f"Invalid sink input field '{sink_name}' in link "
|
||||
f"'{link_id}' to node '{sink_id}'"
|
||||
)
|
||||
|
||||
if DICT_SPLIT in sink_name:
|
||||
if not check_nested_input(
|
||||
input_props, sink_name, context, block_name, block_id
|
||||
):
|
||||
valid = False
|
||||
else:
|
||||
if sink_name not in input_props:
|
||||
available_inputs = list(input_props.keys())
|
||||
self.add_error(
|
||||
f"{context} (block '{block_name}' - {block_id}): "
|
||||
f"Input property '{sink_name}' does not exist in "
|
||||
f"the block's input schema. "
|
||||
f"Available inputs: {available_inputs}"
|
||||
)
|
||||
valid = False
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
node_id = node.get("id")
|
||||
block_id = node.get("block_id", "")
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
if not isinstance(input_default, dict) or not input_default:
|
||||
continue
|
||||
|
||||
if (
|
||||
block_id not in block_input_schemas
|
||||
and block_id != AGENT_EXECUTOR_BLOCK_ID
|
||||
):
|
||||
continue
|
||||
|
||||
input_props = get_input_props(node)
|
||||
|
||||
for key in input_default:
|
||||
if key == "credentials":
|
||||
continue
|
||||
|
||||
context = (
|
||||
f"Node '{node_id}' (block '{block_name}' - {block_id}) "
|
||||
f"has unknown input_default key '{key}'"
|
||||
)
|
||||
|
||||
if DICT_SPLIT in key:
|
||||
if not check_nested_input(
|
||||
input_props, key, context, block_name, block_id
|
||||
):
|
||||
valid = False
|
||||
else:
|
||||
if key not in input_props:
|
||||
available_inputs = list(input_props.keys())
|
||||
self.add_error(
|
||||
f"{context} which does not exist in the "
|
||||
f"block's input schema. "
|
||||
f"Available inputs: {available_inputs}"
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_io_blocks(self, agent: AgentDict) -> bool:
|
||||
"""
|
||||
Validate that the agent has at least one AgentInputBlock and one
|
||||
@@ -827,18 +929,18 @@ class AgentValidator:
|
||||
|
||||
return valid
|
||||
|
||||
def validate_smart_decision_maker_blocks(
|
||||
def validate_orchestrator_blocks(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""Validate that SmartDecisionMakerBlock nodes have downstream tools.
|
||||
"""Validate that OrchestratorBlock nodes have downstream tools.
|
||||
|
||||
Checks that each SmartDecisionMakerBlock node has at least one link
|
||||
Checks that each OrchestratorBlock node has at least one link
|
||||
with ``source_name == "tools"`` connecting to a downstream block.
|
||||
Without tools, the block has nothing to call and will error at runtime.
|
||||
|
||||
Returns True if all SmartDecisionMakerBlock nodes are valid.
|
||||
Returns True if all OrchestratorBlock nodes are valid.
|
||||
"""
|
||||
valid = True
|
||||
nodes = agent.get("nodes", [])
|
||||
@@ -848,7 +950,7 @@ class AgentValidator:
|
||||
non_tool_block_ids = {AGENT_INPUT_BLOCK_ID, AGENT_OUTPUT_BLOCK_ID}
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != SMART_DECISION_MAKER_BLOCK_ID:
|
||||
if node.get("block_id") != TOOL_ORCHESTRATOR_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id", "unknown")
|
||||
@@ -863,7 +965,7 @@ class AgentValidator:
|
||||
max_iter = input_default.get("agent_mode_max_iterations")
|
||||
if max_iter is not None and not isinstance(max_iter, int):
|
||||
self.add_error(
|
||||
f"SmartDecisionMakerBlock node '{customized_name}' "
|
||||
f"OrchestratorBlock node '{customized_name}' "
|
||||
f"({node_id}) has non-integer "
|
||||
f"agent_mode_max_iterations={max_iter!r}. "
|
||||
f"This field must be an integer."
|
||||
@@ -871,7 +973,7 @@ class AgentValidator:
|
||||
valid = False
|
||||
elif isinstance(max_iter, int) and max_iter < -1:
|
||||
self.add_error(
|
||||
f"SmartDecisionMakerBlock node '{customized_name}' "
|
||||
f"OrchestratorBlock node '{customized_name}' "
|
||||
f"({node_id}) has invalid "
|
||||
f"agent_mode_max_iterations={max_iter}. "
|
||||
f"Use -1 for infinite or a positive number for "
|
||||
@@ -880,7 +982,7 @@ class AgentValidator:
|
||||
valid = False
|
||||
elif isinstance(max_iter, int) and max_iter > 100:
|
||||
self.add_error(
|
||||
f"SmartDecisionMakerBlock node '{customized_name}' "
|
||||
f"OrchestratorBlock node '{customized_name}' "
|
||||
f"({node_id}) has agent_mode_max_iterations="
|
||||
f"{max_iter} which is unusually high. Values above "
|
||||
f"100 risk excessive cost and long execution times. "
|
||||
@@ -890,7 +992,7 @@ class AgentValidator:
|
||||
valid = False
|
||||
elif max_iter == 0:
|
||||
self.add_error(
|
||||
f"SmartDecisionMakerBlock node '{customized_name}' "
|
||||
f"OrchestratorBlock node '{customized_name}' "
|
||||
f"({node_id}) has agent_mode_max_iterations=0 "
|
||||
f"(traditional mode). The agent generator only supports "
|
||||
f"agent mode (set to -1 for infinite or a positive "
|
||||
@@ -908,7 +1010,7 @@ class AgentValidator:
|
||||
|
||||
if not has_tools:
|
||||
self.add_error(
|
||||
f"SmartDecisionMakerBlock node '{customized_name}' "
|
||||
f"OrchestratorBlock node '{customized_name}' "
|
||||
f"({node_id}) has no downstream tool blocks connected. "
|
||||
f"Connect at least one block to its 'tools' output so "
|
||||
f"the AI has tools to call."
|
||||
@@ -998,14 +1100,14 @@ class AgentValidator:
|
||||
"Data type compatibility",
|
||||
self.validate_data_type_compatibility(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Nested sink links",
|
||||
self.validate_nested_sink_links(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Source output existence",
|
||||
self.validate_source_output_existence(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Sink input existence",
|
||||
self.validate_sink_input_existence(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Prompt double curly braces spaces",
|
||||
self.validate_prompt_double_curly_braces_spaces(agent),
|
||||
@@ -1025,8 +1127,8 @@ class AgentValidator:
|
||||
self.validate_mcp_tool_blocks(agent),
|
||||
),
|
||||
(
|
||||
"SmartDecisionMaker blocks",
|
||||
self.validate_smart_decision_maker_blocks(agent, node_lookup),
|
||||
"Orchestrator blocks",
|
||||
self.validate_orchestrator_blocks(agent, node_lookup),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -331,43 +331,6 @@ class TestValidatePromptDoubleCurlyBracesSpaces:
|
||||
assert any("spaces" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_nested_sink_links
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateNestedSinkLinks:
|
||||
def test_valid_nested_link_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(sink_id="n1", sink_name="config_#_key", source_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_nested_sink_links(agent, [block]) is True
|
||||
|
||||
def test_invalid_parent_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(block_id="b1")
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(sink_id="n1", sink_name="nonexistent_#_key", source_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_nested_sink_links(agent, [block]) is False
|
||||
assert any("does not exist" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_agent_executor_block_schemas
|
||||
# ============================================================================
|
||||
@@ -595,11 +558,28 @@ class TestValidate:
|
||||
input_block = _make_block(
|
||||
block_id=AGENT_INPUT_BLOCK_ID,
|
||||
name="AgentInputBlock",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"title": {"type": "string"},
|
||||
"value": {},
|
||||
"description": {"type": "string"},
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
output_schema={"properties": {"result": {}}},
|
||||
)
|
||||
output_block = _make_block(
|
||||
block_id=AGENT_OUTPUT_BLOCK_ID,
|
||||
name="AgentOutputBlock",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"title": {"type": "string"},
|
||||
"value": {},
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
)
|
||||
input_node = _make_node(
|
||||
node_id="n-in",
|
||||
@@ -650,6 +630,201 @@ class TestValidate:
|
||||
assert "AgentOutputBlock" in error_message
|
||||
|
||||
|
||||
class TestValidateSinkInputExistence:
|
||||
"""Tests for validate_sink_input_existence."""
|
||||
|
||||
def test_valid_sink_name_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(
|
||||
source_id="src", source_name="out", sink_id="n1", sink_name="url"
|
||||
)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is True
|
||||
|
||||
def test_invalid_sink_name_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(
|
||||
source_id="src", source_name="out", sink_id="n1", sink_name="nonexistent"
|
||||
)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is False
|
||||
assert any("nonexistent" in e for e in v.errors)
|
||||
|
||||
def test_valid_nested_link_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(
|
||||
source_id="src",
|
||||
source_name="out",
|
||||
sink_id="n1",
|
||||
sink_name="config_#_key",
|
||||
)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is True
|
||||
|
||||
def test_invalid_nested_child_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(
|
||||
source_id="src",
|
||||
source_name="out",
|
||||
sink_id="n1",
|
||||
sink_name="config_#_missing",
|
||||
)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is False
|
||||
|
||||
def test_unknown_input_default_key_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
|
||||
)
|
||||
node = _make_node(
|
||||
node_id="n1", block_id="b1", input_default={"nonexistent_key": "value"}
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is False
|
||||
assert any("nonexistent_key" in e for e in v.errors)
|
||||
|
||||
def test_credentials_key_skipped(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
|
||||
)
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id="b1",
|
||||
input_default={
|
||||
"url": "http://example.com",
|
||||
"credentials": {"api_key": "x"},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is True
|
||||
|
||||
def test_agent_executor_dynamic_schema_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_schema={
|
||||
"properties": {
|
||||
"graph_id": {"type": "string"},
|
||||
"input_schema": {"type": "object"},
|
||||
},
|
||||
"required": ["graph_id"],
|
||||
},
|
||||
)
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": "abc",
|
||||
"input_schema": {
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
)
|
||||
link = _make_link(
|
||||
source_id="src",
|
||||
source_name="out",
|
||||
sink_id="n1",
|
||||
sink_name="query",
|
||||
)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is True
|
||||
|
||||
def test_input_default_nested_invalid_child_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id="b1",
|
||||
input_default={"config_#_invalid_child": "value"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is False
|
||||
assert any("invalid_child" in e for e in v.errors)
|
||||
|
||||
def test_input_default_nested_valid_child_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id="b1",
|
||||
input_default={"config_#_key": "value"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_sink_input_existence(agent, [block]) is True
|
||||
|
||||
|
||||
class TestValidateMCPToolBlocks:
|
||||
"""Tests for validate_mcp_tool_blocks."""
|
||||
|
||||
|
||||
@@ -10,7 +10,12 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import execution_db, library_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
GraphExecutionMeta,
|
||||
GraphExecutionWithNodes,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .execution_utils import TERMINAL_STATUSES, wait_for_execution
|
||||
@@ -35,6 +40,7 @@ class AgentOutputInput(BaseModel):
|
||||
execution_id: str = ""
|
||||
run_time: str = "latest"
|
||||
wait_if_running: int = Field(default=0, ge=0, le=300)
|
||||
show_execution_details: bool = False
|
||||
|
||||
@field_validator(
|
||||
"agent_name",
|
||||
@@ -146,6 +152,10 @@ class AgentOutputTool(BaseTool):
|
||||
"minimum": 0,
|
||||
"maximum": 300,
|
||||
},
|
||||
"show_execution_details": {
|
||||
"type": "boolean",
|
||||
"description": "If true, include full node-by-node execution trace (inputs, outputs, status, timing for each node). Useful for debugging agent wiring. Default: false.",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
@@ -226,13 +236,19 @@ class AgentOutputTool(BaseTool):
|
||||
time_start: datetime | None,
|
||||
time_end: datetime | None,
|
||||
include_running: bool = False,
|
||||
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
|
||||
include_node_executions: bool = False,
|
||||
) -> tuple[
|
||||
GraphExecution | GraphExecutionWithNodes | None,
|
||||
list[GraphExecutionMeta],
|
||||
str | None,
|
||||
]:
|
||||
"""
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
|
||||
Args:
|
||||
include_running: If True, also look for running/queued executions (for waiting)
|
||||
include_node_executions: If True, include node-by-node execution details
|
||||
"""
|
||||
exec_db = execution_db()
|
||||
|
||||
@@ -241,7 +257,7 @@ class AgentOutputTool(BaseTool):
|
||||
execution = await exec_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
include_node_executions=include_node_executions,
|
||||
)
|
||||
if not execution:
|
||||
return None, [], f"Execution '{execution_id}' not found"
|
||||
@@ -279,7 +295,7 @@ class AgentOutputTool(BaseTool):
|
||||
full_execution = await exec_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
include_node_executions=include_node_executions,
|
||||
)
|
||||
return full_execution, [], None
|
||||
|
||||
@@ -287,14 +303,14 @@ class AgentOutputTool(BaseTool):
|
||||
full_execution = await exec_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
include_node_executions=include_node_executions,
|
||||
)
|
||||
return full_execution, executions, None
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
agent: LibraryAgent,
|
||||
execution: GraphExecution | None,
|
||||
execution: GraphExecution | GraphExecutionWithNodes | None,
|
||||
available_executions: list[GraphExecutionMeta],
|
||||
session_id: str | None,
|
||||
) -> AgentOutputResponse:
|
||||
@@ -312,6 +328,21 @@ class AgentOutputTool(BaseTool):
|
||||
total_executions=0,
|
||||
)
|
||||
|
||||
node_executions_data = None
|
||||
if isinstance(execution, GraphExecutionWithNodes):
|
||||
node_executions_data = [
|
||||
{
|
||||
"node_id": ne.node_id,
|
||||
"block_id": ne.block_id,
|
||||
"status": ne.status.value,
|
||||
"input_data": ne.input_data,
|
||||
"output_data": dict(ne.output_data),
|
||||
"start_time": ne.start_time.isoformat() if ne.start_time else None,
|
||||
"end_time": ne.end_time.isoformat() if ne.end_time else None,
|
||||
}
|
||||
for ne in execution.node_executions
|
||||
]
|
||||
|
||||
execution_info = ExecutionOutputInfo(
|
||||
execution_id=execution.id,
|
||||
status=execution.status.value,
|
||||
@@ -319,6 +350,7 @@ class AgentOutputTool(BaseTool):
|
||||
ended_at=execution.ended_at,
|
||||
outputs=dict(execution.outputs),
|
||||
inputs_summary=execution.inputs if execution.inputs else None,
|
||||
node_executions=node_executions_data,
|
||||
)
|
||||
|
||||
available_list = None
|
||||
@@ -428,7 +460,7 @@ class AgentOutputTool(BaseTool):
|
||||
execution = await execution_db().get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=input_data.execution_id,
|
||||
include_node_executions=False,
|
||||
include_node_executions=input_data.show_execution_details,
|
||||
)
|
||||
if not execution:
|
||||
return ErrorResponse(
|
||||
@@ -484,6 +516,7 @@ class AgentOutputTool(BaseTool):
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
include_running=wait_timeout > 0,
|
||||
include_node_executions=input_data.show_execution_details,
|
||||
)
|
||||
|
||||
if exec_error:
|
||||
|
||||
20
autogpt_platform/backend/backend/copilot/tools/conftest.py
Normal file
20
autogpt_platform/backend/backend/copilot/tools/conftest.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Local conftest for copilot/tools tests.
|
||||
|
||||
Overrides the session-scoped `server` and `graph_cleanup` autouse fixtures from
|
||||
backend/conftest.py so that integration tests in this directory do not trigger
|
||||
the full SpinTestServer startup (which requires Postgres + RabbitMQ).
|
||||
"""
|
||||
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||
async def server(): # type: ignore[override]
|
||||
"""No-op server stub — tools tests don't need the full backend."""
|
||||
return None
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
|
||||
async def graph_cleanup(): # type: ignore[override]
|
||||
"""No-op graph cleanup stub."""
|
||||
yield
|
||||
@@ -119,8 +119,11 @@ class ContinueRunBlockTool(BaseTool):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Continuing block {block.name} ({block_id}) for user {user_id} "
|
||||
f"with review_id={review_id}"
|
||||
"Continuing block %s (%s) for user %s with review_id=%s",
|
||||
block.name,
|
||||
block_id,
|
||||
user_id,
|
||||
review_id,
|
||||
)
|
||||
|
||||
matched_creds, missing_creds = await resolve_block_credentials(
|
||||
@@ -132,6 +135,9 @@ class ContinueRunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# dry_run=False is safe here: run_block's dry-run fast-path (line ~241)
|
||||
# skips HITL entirely, so continue_run_block is never called during a
|
||||
# dry run — only real executions reach the human review gate.
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id=block_id,
|
||||
@@ -140,6 +146,7 @@ class ContinueRunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
node_exec_id=review_id,
|
||||
matched_credentials=matched_creds,
|
||||
dry_run=False,
|
||||
)
|
||||
|
||||
# Delete review record after successful execution (one-time use)
|
||||
|
||||
@@ -5,6 +5,7 @@ from prisma.enums import ContentType
|
||||
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.context import get_current_permissions
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import search
|
||||
|
||||
@@ -38,7 +39,7 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||
|
||||
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||
COPILOT_EXCLUDED_BLOCK_IDS = {
|
||||
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology;
|
||||
# OrchestratorBlock - dynamically discovers downstream blocks via graph topology;
|
||||
# usable in agent graphs (guide hardcodes its ID) but cannot run standalone.
|
||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
}
|
||||
@@ -149,6 +150,19 @@ class FindBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check block-level permissions — hide denied blocks entirely
|
||||
perms = get_current_permissions()
|
||||
if perms is not None and not perms.is_block_allowed(
|
||||
block.id, block.name
|
||||
):
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found for '{query}'",
|
||||
suggestions=[
|
||||
"Search for an alternative block by name",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
summary = BlockInfoSummary(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
@@ -195,6 +209,7 @@ class FindBlockTool(BaseTool):
|
||||
)
|
||||
|
||||
# Enrich results with block information
|
||||
perms = get_current_permissions()
|
||||
blocks: list[BlockInfoSummary] = []
|
||||
for result in results:
|
||||
block_id = result["content_id"]
|
||||
@@ -211,6 +226,12 @@ class FindBlockTool(BaseTool):
|
||||
):
|
||||
continue
|
||||
|
||||
# Skip blocks denied by execution permissions
|
||||
if perms is not None and not perms.is_block_allowed(
|
||||
block.id, block.name
|
||||
):
|
||||
continue
|
||||
|
||||
summary = BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
|
||||
@@ -69,8 +69,8 @@ class TestFindBlockFiltering:
|
||||
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
|
||||
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
||||
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||
def test_excluded_block_ids_contains_orchestrator(self):
|
||||
"""Verify OrchestratorBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -120,18 +120,18 @@ class TestFindBlockFiltering:
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_id_filtered_from_results(self):
|
||||
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
||||
"""Verify OrchestratorBlock is filtered from search results."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
orchestrator_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
search_results = [
|
||||
{"content_id": smart_decision_id, "score": 0.9},
|
||||
{"content_id": orchestrator_id, "score": 0.9},
|
||||
{"content_id": "normal-block-id", "score": 0.8},
|
||||
]
|
||||
|
||||
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
||||
# OrchestratorBlock has STANDARD type but is excluded by ID
|
||||
smart_block = make_mock_block(
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
orchestrator_id, "Orchestrator", BlockType.STANDARD
|
||||
)
|
||||
normal_block = make_mock_block(
|
||||
"normal-block-id", "Normal Block", BlockType.STANDARD
|
||||
@@ -139,7 +139,7 @@ class TestFindBlockFiltering:
|
||||
|
||||
def mock_get_block(block_id):
|
||||
return {
|
||||
smart_decision_id: smart_block,
|
||||
orchestrator_id: smart_block,
|
||||
"normal-block-id": normal_block,
|
||||
}.get(block_id)
|
||||
|
||||
@@ -161,7 +161,7 @@ class TestFindBlockFiltering:
|
||||
user_id=_TEST_USER_ID, session=session, query="decision"
|
||||
)
|
||||
|
||||
# Should only return normal block, not SmartDecisionMakerBlock
|
||||
# Should only return normal block, not OrchestratorBlock
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert len(response.blocks) == 1
|
||||
assert response.blocks[0].id == "normal-block-id"
|
||||
@@ -601,10 +601,8 @@ class TestFindBlockDirectLookup:
|
||||
async def test_uuid_lookup_excluded_block_id(self):
|
||||
"""UUID matching an excluded block ID returns NoResultsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
block = make_mock_block(
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
)
|
||||
orchestrator_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
block = make_mock_block(orchestrator_id, "Orchestrator", BlockType.STANDARD)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
@@ -612,7 +610,7 @@ class TestFindBlockDirectLookup:
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=smart_decision_id
|
||||
user_id=_TEST_USER_ID, session=session, query=orchestrator_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
@@ -1,24 +1,46 @@
|
||||
"""Shared helpers for chat tools."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks import BlockType, get_block
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.copilot.constants import (
|
||||
COPILOT_NODE_EXEC_ID_SEPARATOR,
|
||||
COPILOT_NODE_PREFIX,
|
||||
COPILOT_SESSION_PREFIX,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.db_accessors import credit_db, workspace_db
|
||||
from backend.data.db_accessors import credit_db, review_db, workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.executor.simulator import simulate_block
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError, InsufficientBalanceError
|
||||
from backend.util.type import coerce_inputs_to_schema
|
||||
|
||||
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
||||
from .utils import match_credentials_to_requirements
|
||||
from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
InputValidationErrorResponse,
|
||||
ReviewRequiredResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from .utils import (
|
||||
build_missing_credentials_from_field_info,
|
||||
match_credentials_to_requirements,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,6 +81,7 @@ async def execute_block(
|
||||
node_exec_id: str,
|
||||
matched_credentials: dict[str, CredentialsMetaInput],
|
||||
sensitive_action_safe_mode: bool = False,
|
||||
dry_run: bool = False,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute a block with full context setup, credential injection, and error handling.
|
||||
|
||||
@@ -68,6 +91,49 @@ async def execute_block(
|
||||
Returns:
|
||||
BlockOutputResponse on success, ErrorResponse on failure.
|
||||
"""
|
||||
# Dry-run path: simulate the block with an LLM, no real execution.
|
||||
# HITL review is intentionally skipped — no real execution occurs.
|
||||
if dry_run:
|
||||
try:
|
||||
# Coerce types to match the block's input schema, same as real execution.
|
||||
# This ensures the simulated preview is consistent with real execution
|
||||
# (e.g., "42" → 42, string booleans → bool, enum defaults applied).
|
||||
coerce_inputs_to_schema(input_data, block.input_schema)
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in simulate_block(block, input_data):
|
||||
outputs[output_name].append(output_data)
|
||||
# simulator signals internal failure via ("error", "[SIMULATOR ERROR …]")
|
||||
sim_error = outputs.get("error", [])
|
||||
if (
|
||||
sim_error
|
||||
and isinstance(sim_error[0], str)
|
||||
and sim_error[0].startswith("[SIMULATOR ERROR")
|
||||
):
|
||||
return ErrorResponse(
|
||||
message=sim_error[0],
|
||||
error=sim_error[0],
|
||||
session_id=session_id,
|
||||
)
|
||||
return BlockOutputResponse(
|
||||
message=(
|
||||
f"[DRY RUN] Block '{block.name}' simulated successfully "
|
||||
"— no real execution occurred."
|
||||
),
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
is_dry_run=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Dry-run simulation failed: %s", e, exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Dry-run simulation failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
|
||||
@@ -231,11 +297,301 @@ async def resolve_block_credentials(
|
||||
return await match_credentials_to_requirements(user_id, requirements)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockPreparation:
|
||||
"""Result of successful block validation, ready for execution or task creation.
|
||||
|
||||
Attributes:
|
||||
block: The resolved block instance (schema definition + execute method).
|
||||
block_id: UUID of the block being prepared.
|
||||
input_data: User-supplied input values after file-ref expansion.
|
||||
matched_credentials: Credential field name -> resolved credential metadata.
|
||||
input_schema: JSON Schema for the block's input, with credential
|
||||
discriminators resolved for the user's available providers.
|
||||
credentials_fields: Set of field names in the schema that are credential
|
||||
inputs (e.g. ``{"credentials", "api_key"}``).
|
||||
required_non_credential_keys: Schema-required fields minus credential
|
||||
fields — the fields the user must supply directly.
|
||||
provided_input_keys: Keys the user actually provided in ``input_data``.
|
||||
synthetic_graph_id: Auto-generated graph UUID used for CoPilot
|
||||
single-block executions (no real graph exists in the DB).
|
||||
synthetic_node_id: Auto-generated node UUID paired with
|
||||
``synthetic_graph_id`` to form the execution context for the block.
|
||||
"""
|
||||
|
||||
block: AnyBlockSchema
|
||||
block_id: str
|
||||
input_data: dict[str, Any]
|
||||
matched_credentials: dict[str, CredentialsMetaInput]
|
||||
input_schema: dict[str, Any]
|
||||
credentials_fields: set[str]
|
||||
required_non_credential_keys: set[str]
|
||||
provided_input_keys: set[str]
|
||||
synthetic_graph_id: str
|
||||
synthetic_node_id: str
|
||||
|
||||
|
||||
async def prepare_block_for_execution(
|
||||
block_id: str,
|
||||
input_data: dict[str, Any],
|
||||
user_id: str,
|
||||
session: ChatSession,
|
||||
session_id: str,
|
||||
dry_run: bool = False,
|
||||
) -> "BlockPreparation | ToolResponseBase":
|
||||
"""Validate and prepare a block for execution.
|
||||
|
||||
Performs: block lookup, disabled/excluded-type checks, credential resolution,
|
||||
input schema generation, file-ref expansion, missing-credentials check, and
|
||||
unrecognized-field validation.
|
||||
|
||||
Does NOT check for missing required fields (tools differ: run_block shows a
|
||||
schema preview) and does NOT run the HITL review check (use check_hitl_review
|
||||
separately).
|
||||
|
||||
Args:
|
||||
block_id: Block UUID to prepare.
|
||||
input_data: Input values provided by the caller.
|
||||
user_id: Authenticated user ID.
|
||||
session: Current chat session (needed for file-ref expansion).
|
||||
session_id: Chat session ID (used in error responses).
|
||||
|
||||
Returns:
|
||||
BlockPreparation on success, or a ToolResponseBase error/setup response.
|
||||
"""
|
||||
# Lazy import: find_block imports from .base and .models (siblings), not
|
||||
# from helpers — no actual circular dependency exists today. Kept lazy as a
|
||||
# precaution since find_block is the block-registry module and future changes
|
||||
# could introduce a cycle.
|
||||
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' not found", session_id=session_id
|
||||
)
|
||||
if block.disabled:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' is disabled", session_id=session_id
|
||||
)
|
||||
|
||||
if (
|
||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
):
|
||||
if block.block_type == BlockType.MCP_TOOL:
|
||||
hint = (
|
||||
" Use the `run_mcp_tool` tool instead — it handles "
|
||||
"MCP server discovery, authentication, and execution."
|
||||
)
|
||||
elif block.block_type == BlockType.AGENT:
|
||||
hint = " Use the `run_agent` tool instead."
|
||||
else:
|
||||
hint = " This block is designed for use within graphs only."
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block.name}' cannot be run directly.{hint}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
matched_credentials, missing_credentials = await resolve_block_credentials(
|
||||
user_id, block, input_data
|
||||
)
|
||||
|
||||
try:
|
||||
input_schema: dict[str, Any] = block.input_schema.jsonschema()
|
||||
except Exception as e:
|
||||
logger.warning("Failed to generate input schema for block %s: %s", block_id, e)
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block.name}' has an invalid input schema",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Expand @@agptfile: refs using the block's input schema so string/list
|
||||
# fields get the correct deserialization.
|
||||
if input_data:
|
||||
try:
|
||||
input_data = await expand_file_refs_in_args(
|
||||
input_data, user_id, session, input_schema=input_schema
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Failed to resolve file reference: {exc}. "
|
||||
"Ensure the file exists before referencing it."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
|
||||
if missing_credentials and not dry_run:
|
||||
credentials_fields_info = _resolve_discriminated_credentials(block, input_data)
|
||||
missing_creds_dict = build_missing_credentials_from_field_info(
|
||||
credentials_fields_info, set(matched_credentials.keys())
|
||||
)
|
||||
missing_creds_list = list(missing_creds_dict.values())
|
||||
return SetupRequirementsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires credentials that are not configured. "
|
||||
"Please set up the required credentials before running this block."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=block_id,
|
||||
agent_name=block.name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": missing_creds_list,
|
||||
"inputs": get_inputs_from_schema(
|
||||
input_schema, exclude_fields=credentials_fields
|
||||
),
|
||||
"execution_modes": ["immediate"],
|
||||
},
|
||||
),
|
||||
graph_id=None,
|
||||
graph_version=None,
|
||||
)
|
||||
required_keys = set(input_schema.get("required", []))
|
||||
required_non_credential_keys = required_keys - credentials_fields
|
||||
provided_input_keys = set(input_data.keys()) - credentials_fields
|
||||
|
||||
valid_fields = set(input_schema.get("properties", {}).keys()) - credentials_fields
|
||||
unrecognized_fields = provided_input_keys - valid_fields
|
||||
if unrecognized_fields:
|
||||
return InputValidationErrorResponse(
|
||||
message=(
|
||||
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
|
||||
"Block was not executed. Please use the correct field names from the schema."
|
||||
),
|
||||
session_id=session_id,
|
||||
unrecognized_fields=sorted(unrecognized_fields),
|
||||
inputs=input_schema,
|
||||
)
|
||||
|
||||
synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
|
||||
synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"
|
||||
|
||||
return BlockPreparation(
|
||||
block=block,
|
||||
block_id=block_id,
|
||||
input_data=input_data,
|
||||
matched_credentials=matched_credentials,
|
||||
input_schema=input_schema,
|
||||
credentials_fields=credentials_fields,
|
||||
required_non_credential_keys=required_non_credential_keys,
|
||||
provided_input_keys=provided_input_keys,
|
||||
synthetic_graph_id=synthetic_graph_id,
|
||||
synthetic_node_id=synthetic_node_id,
|
||||
)
|
||||
|
||||
|
||||
async def check_hitl_review(
|
||||
prep: BlockPreparation,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
) -> "tuple[str, dict[str, Any]] | ToolResponseBase":
|
||||
"""Check for an existing or new HITL review requirement.
|
||||
|
||||
If a review is needed, stores the review record and returns a
|
||||
ReviewRequiredResponse. Otherwise returns
|
||||
``(synthetic_node_exec_id, input_data)`` ready for execute_block.
|
||||
"""
|
||||
block = prep.block
|
||||
block_id = prep.block_id
|
||||
synthetic_graph_id = prep.synthetic_graph_id
|
||||
synthetic_node_id = prep.synthetic_node_id
|
||||
input_data = prep.input_data
|
||||
|
||||
# Reuse an existing WAITING review for identical input (LLM retry guard)
|
||||
existing_reviews = await review_db().get_pending_reviews_for_execution(
|
||||
synthetic_graph_id, user_id
|
||||
)
|
||||
existing_review = next(
|
||||
(
|
||||
r
|
||||
for r in existing_reviews
|
||||
if r.node_id == synthetic_node_id
|
||||
and r.status.value == "WAITING"
|
||||
and r.payload == input_data
|
||||
),
|
||||
None,
|
||||
)
|
||||
if existing_review:
|
||||
return ReviewRequiredResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires human review. "
|
||||
f"After the user approves, call continue_run_block with "
|
||||
f"review_id='{existing_review.node_exec_id}' to execute."
|
||||
),
|
||||
session_id=session_id,
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
review_id=existing_review.node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
synthetic_node_exec_id = (
|
||||
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
review_context = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_version=1,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
)
|
||||
should_pause, input_data = await block.is_block_exec_need_review(
|
||||
input_data,
|
||||
user_id=user_id,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_version=1,
|
||||
execution_context=review_context,
|
||||
is_graph_execution=False,
|
||||
)
|
||||
if should_pause:
|
||||
return ReviewRequiredResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires human review. "
|
||||
f"After the user approves, call continue_run_block with "
|
||||
f"review_id='{synthetic_node_exec_id}' to execute."
|
||||
),
|
||||
session_id=session_id,
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
review_id=synthetic_node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
return synthetic_node_exec_id, input_data
|
||||
|
||||
|
||||
def _resolve_discriminated_credentials(
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
"""Resolve credential requirements, applying discriminator logic where needed.
|
||||
|
||||
Handles two discrimination modes:
|
||||
1. **Provider-based** (``discriminator_mapping`` is set): the discriminator
|
||||
field value selects the provider (e.g. an AI model name -> provider).
|
||||
2. **URL/host-based** (``discriminator`` is set but ``discriminator_mapping``
|
||||
is ``None``): the discriminator field value (typically a URL) is added to
|
||||
``discriminator_values`` so that host-scoped credential matching can
|
||||
compare the credential's host against the target URL.
|
||||
"""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
@@ -245,23 +601,42 @@ def _resolve_discriminated_credentials(
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
if field_info.discriminator:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(field_info.discriminator)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
if discriminator_value is not None:
|
||||
if field_info.discriminator_mapping:
|
||||
# Provider-based discrimination (e.g. model -> provider)
|
||||
if discriminator_value in field_info.discriminator_mapping:
|
||||
effective_field_info = field_info.discriminate(
|
||||
discriminator_value
|
||||
)
|
||||
effective_field_info.discriminator_values.add(
|
||||
discriminator_value
|
||||
)
|
||||
# Model names are safe to log (not PII); URLs are
|
||||
# intentionally omitted in the host-based branch below.
|
||||
logger.debug(
|
||||
"Discriminated provider for %s: %s -> %s",
|
||||
field_name,
|
||||
discriminator_value,
|
||||
effective_field_info.provider,
|
||||
)
|
||||
else:
|
||||
# URL/host-based discrimination (e.g. url -> host matching).
|
||||
# Deep copy to avoid mutating the cached schema-level
|
||||
# field_info (model_copy() is shallow — the mutable set
|
||||
# would be shared).
|
||||
effective_field_info = field_info.model_copy(deep=True)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
"Added discriminator value for host matching on %s",
|
||||
field_name,
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Tests for execute_block — credit charging and type coercion."""
|
||||
"""Tests for execute_block, prepare_block_for_execution, and check_hitl_review."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
@@ -7,8 +7,20 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.helpers import execute_block
|
||||
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.copilot.tools.helpers import (
|
||||
BlockPreparation,
|
||||
check_hitl_review,
|
||||
execute_block,
|
||||
prepare_block_for_execution,
|
||||
)
|
||||
from backend.copilot.tools.models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
InputValidationErrorResponse,
|
||||
ReviewRequiredResponse,
|
||||
SetupRequirementsResponse,
|
||||
)
|
||||
|
||||
_USER = "test-user-helpers"
|
||||
_SESSION = "test-session-helpers"
|
||||
@@ -510,3 +522,341 @@ async def test_coerce_inner_elements_of_generic():
|
||||
# Inner elements should be coerced from int to str
|
||||
assert block._captured_inputs["values"] == ["1", "2", "3"]
|
||||
assert all(isinstance(v, str) for v in block._captured_inputs["values"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prepare_block_for_execution tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PREP_USER = "prep-user"
|
||||
_PREP_SESSION = "prep-session"
|
||||
|
||||
|
||||
def _make_prep_session(session_id: str = _PREP_SESSION) -> MagicMock:
|
||||
session = MagicMock()
|
||||
session.session_id = session_id
|
||||
return session
|
||||
|
||||
|
||||
def _make_simple_block(
|
||||
block_id: str = "blk-1",
|
||||
name: str = "Simple Block",
|
||||
disabled: bool = False,
|
||||
required: list[str] | None = None,
|
||||
properties: dict[str, Any] | None = None,
|
||||
) -> MagicMock:
|
||||
block = MagicMock()
|
||||
block.id = block_id
|
||||
block.name = name
|
||||
block.disabled = disabled
|
||||
block.description = ""
|
||||
block.block_type = MagicMock()
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": properties or {"text": {"type": "string"}},
|
||||
"required": required or [],
|
||||
}
|
||||
block.input_schema.jsonschema.return_value = schema
|
||||
block.input_schema.get_credentials_fields.return_value = {}
|
||||
block.input_schema.get_credentials_fields_info.return_value = {}
|
||||
return block
|
||||
|
||||
|
||||
def _patch_excluded(block_ids: set | None = None, block_types: set | None = None):
|
||||
return (
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.COPILOT_EXCLUDED_BLOCK_IDS",
|
||||
new=block_ids or set(),
|
||||
create=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.COPILOT_EXCLUDED_BLOCK_TYPES",
|
||||
new=block_types or set(),
|
||||
create=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_block_not_found() -> None:
|
||||
excl_ids, excl_types = _patch_excluded()
|
||||
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=None),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="missing",
|
||||
input_data={},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "not found" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_block_disabled() -> None:
|
||||
block = _make_simple_block(disabled=True)
|
||||
excl_ids, excl_types = _patch_excluded()
|
||||
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=block),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="blk-1",
|
||||
input_data={},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "disabled" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_block_unrecognized_fields() -> None:
|
||||
block = _make_simple_block(properties={"text": {"type": "string"}})
|
||||
excl_ids, excl_types = _patch_excluded()
|
||||
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=block),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.resolve_block_credentials",
|
||||
AsyncMock(return_value=({}, [])),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.expand_file_refs_in_args",
|
||||
AsyncMock(side_effect=lambda d, *a, **kw: d),
|
||||
),
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="blk-1",
|
||||
input_data={"text": "hi", "unknown_field": "oops"},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, InputValidationErrorResponse)
|
||||
assert "unknown_field" in result.unrecognized_fields
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_block_missing_credentials() -> None:
|
||||
block = _make_simple_block()
|
||||
mock_cred = MagicMock()
|
||||
excl_ids, excl_types = _patch_excluded()
|
||||
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=block),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.resolve_block_credentials",
|
||||
AsyncMock(return_value=({}, [mock_cred])),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.build_missing_credentials_from_field_info",
|
||||
return_value={"cred_key": mock_cred},
|
||||
),
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="blk-1",
|
||||
input_data={},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_block_success_returns_preparation() -> None:
|
||||
block = _make_simple_block(
|
||||
required=["text"], properties={"text": {"type": "string"}}
|
||||
)
|
||||
excl_ids, excl_types = _patch_excluded()
|
||||
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=block),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.resolve_block_credentials",
|
||||
AsyncMock(return_value=({}, [])),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.expand_file_refs_in_args",
|
||||
AsyncMock(side_effect=lambda d, *a, **kw: d),
|
||||
),
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="blk-1",
|
||||
input_data={"text": "hello"},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, BlockPreparation)
|
||||
assert result.required_non_credential_keys == {"text"}
|
||||
assert result.provided_input_keys == {"text"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_hitl_review tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_hitl_prep(
|
||||
block_id: str = "blk-hitl",
|
||||
input_data: dict | None = None,
|
||||
session_id: str = "hitl-sess",
|
||||
needs_review: bool = False,
|
||||
) -> BlockPreparation:
|
||||
block = MagicMock()
|
||||
block.id = block_id
|
||||
block.name = "HITL Block"
|
||||
data = input_data if input_data is not None else {"action": "delete"}
|
||||
block.is_block_exec_need_review = AsyncMock(return_value=(needs_review, data))
|
||||
return BlockPreparation(
|
||||
block=block,
|
||||
block_id=block_id,
|
||||
input_data=data,
|
||||
matched_credentials={},
|
||||
input_schema={},
|
||||
credentials_fields=set(),
|
||||
required_non_credential_keys=set(),
|
||||
provided_input_keys=set(),
|
||||
synthetic_graph_id=f"{COPILOT_SESSION_PREFIX}{session_id}",
|
||||
synthetic_node_id=f"{COPILOT_NODE_PREFIX}{block_id}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_hitl_no_review_needed() -> None:
|
||||
prep = _make_hitl_prep(input_data={"action": "read"}, needs_review=False)
|
||||
mock_rdb = MagicMock()
|
||||
mock_rdb.get_pending_reviews_for_execution = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.tools.helpers.review_db", return_value=mock_rdb):
|
||||
result = await check_hitl_review(prep, "user1", "hitl-sess")
|
||||
|
||||
assert isinstance(result, tuple)
|
||||
node_exec_id, returned_data = result
|
||||
assert node_exec_id.startswith(f"{COPILOT_NODE_PREFIX}blk-hitl")
|
||||
assert returned_data == {"action": "read"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_hitl_review_required() -> None:
|
||||
prep = _make_hitl_prep(input_data={"action": "delete"}, needs_review=True)
|
||||
mock_rdb = MagicMock()
|
||||
mock_rdb.get_pending_reviews_for_execution = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.tools.helpers.review_db", return_value=mock_rdb):
|
||||
result = await check_hitl_review(prep, "user1", "hitl-sess")
|
||||
|
||||
assert isinstance(result, ReviewRequiredResponse)
|
||||
assert result.block_id == "blk-hitl"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_hitl_reuses_existing_waiting_review() -> None:
|
||||
prep = _make_hitl_prep(input_data={"action": "delete"}, needs_review=False)
|
||||
|
||||
existing = MagicMock()
|
||||
existing.node_id = prep.synthetic_node_id
|
||||
existing.status.value = "WAITING"
|
||||
existing.payload = {"action": "delete"}
|
||||
existing.node_exec_id = "existing-review-42"
|
||||
|
||||
mock_rdb = MagicMock()
|
||||
mock_rdb.get_pending_reviews_for_execution = AsyncMock(return_value=[existing])
|
||||
|
||||
with patch("backend.copilot.tools.helpers.review_db", return_value=mock_rdb):
|
||||
result = await check_hitl_review(prep, "user1", "hitl-sess")
|
||||
|
||||
assert isinstance(result, ReviewRequiredResponse)
|
||||
assert result.review_id == "existing-review-42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_block_excluded_by_type() -> None:
|
||||
"""prepare_block_for_execution returns ErrorResponse for excluded block types."""
|
||||
from backend.blocks import BlockType
|
||||
|
||||
block = _make_simple_block()
|
||||
block.block_type = BlockType.AGENT
|
||||
|
||||
excl_ids, excl_types = _patch_excluded(block_types={BlockType.AGENT})
|
||||
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=block),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="blk-agent",
|
||||
input_data={},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "cannot be run directly" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_block_excluded_by_id() -> None:
|
||||
"""prepare_block_for_execution returns ErrorResponse for excluded block IDs."""
|
||||
block = _make_simple_block(block_id="blk-excluded")
|
||||
|
||||
excl_ids, excl_types = _patch_excluded(block_ids={"blk-excluded"})
|
||||
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=block),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="blk-excluded",
|
||||
input_data={},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "cannot be run directly" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_block_file_ref_expansion_error() -> None:
|
||||
"""prepare_block_for_execution returns ErrorResponse when file-ref expansion fails."""
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError
|
||||
|
||||
block = _make_simple_block(properties={"text": {"type": "string"}})
|
||||
excl_ids, excl_types = _patch_excluded()
|
||||
with (
|
||||
patch("backend.copilot.tools.helpers.get_block", return_value=block),
|
||||
excl_ids,
|
||||
excl_types,
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.resolve_block_credentials",
|
||||
AsyncMock(return_value=({}, [])),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.expand_file_refs_in_args",
|
||||
AsyncMock(
|
||||
side_effect=FileRefExpansionError("@@agptfile:missing.txt not found")
|
||||
),
|
||||
),
|
||||
):
|
||||
result = await prepare_block_for_execution(
|
||||
block_id="blk-1",
|
||||
input_data={"text": "@@agptfile:missing.txt"},
|
||||
user_id=_PREP_USER,
|
||||
session=_make_prep_session(),
|
||||
session_id=_PREP_SESSION,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "file reference" in result.message.lower()
|
||||
|
||||
@@ -0,0 +1,916 @@
|
||||
"""Tests for credential resolution across all credential types in the CoPilot.
|
||||
|
||||
These tests verify that:
|
||||
1. `_resolve_discriminated_credentials` correctly populates discriminator_values
|
||||
for URL-based (host-scoped) and provider-based (api_key) credential fields.
|
||||
2. `find_matching_credential` correctly matches credentials for all types:
|
||||
APIKeyCredentials, OAuth2Credentials, UserPasswordCredentials, and
|
||||
HostScopedCredentials.
|
||||
3. The full `resolve_block_credentials` flow correctly resolves matching
|
||||
credentials or reports them as missing for each credential type.
|
||||
4. `RunBlockTool._execute` end-to-end tests return correct response types.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.http import SendAuthenticatedWebRequestBlock
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
from ._test_data import make_session
|
||||
from .helpers import _resolve_discriminated_credentials, resolve_block_credentials
|
||||
from .models import BlockDetailsResponse, SetupRequirementsResponse
|
||||
from .run_block import RunBlockTool
|
||||
from .utils import find_matching_credential
|
||||
|
||||
_TEST_USER_ID = "test-user-http-cred"
|
||||
|
||||
# Properly typed constants to avoid type: ignore on CredentialsFieldInfo construction.
|
||||
_HOST_SCOPED_TYPES: frozenset[CredentialsType] = frozenset(["host_scoped"])
|
||||
_API_KEY_TYPES: frozenset[CredentialsType] = frozenset(["api_key"])
|
||||
_OAUTH2_TYPES: frozenset[CredentialsType] = frozenset(["oauth2"])
|
||||
_USER_PASSWORD_TYPES: frozenset[CredentialsType] = frozenset(["user_password"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_discriminated_credentials tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveDiscriminatedCredentials:
|
||||
"""Tests for _resolve_discriminated_credentials with URL-based discrimination."""
|
||||
|
||||
def _get_auth_block(self):
|
||||
return SendAuthenticatedWebRequestBlock()
|
||||
|
||||
def test_url_discriminator_populates_discriminator_values(self):
|
||||
"""When input_data contains a URL, discriminator_values should include it."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
assert "https://api.example.com/v1/data" in field_info.discriminator_values
|
||||
|
||||
def test_url_discriminator_without_url_keeps_empty_values(self):
|
||||
"""When no URL is provided, discriminator_values should remain empty."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
assert len(field_info.discriminator_values) == 0
|
||||
|
||||
def test_url_discriminator_does_not_mutate_original_field_info(self):
|
||||
"""The original block schema field_info must not be mutated."""
|
||||
block = self._get_auth_block()
|
||||
|
||||
# Grab a reference to the original schema-level field_info
|
||||
original_info = block.input_schema.get_credentials_fields_info()["credentials"]
|
||||
|
||||
# Call with a URL, which adds to discriminator_values on the copy
|
||||
_resolve_discriminated_credentials(
|
||||
block, {"url": "https://api.example.com/v1/data"}
|
||||
)
|
||||
|
||||
# The original object must remain unchanged
|
||||
assert len(original_info.discriminator_values) == 0
|
||||
|
||||
# And a fresh call without URL should also return empty values
|
||||
result = _resolve_discriminated_credentials(block, {})
|
||||
field_info = result["credentials"]
|
||||
assert len(field_info.discriminator_values) == 0
|
||||
|
||||
def test_url_discriminator_preserves_provider_and_type(self):
|
||||
"""Provider and supported_types should be preserved after URL discrimination."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
field_info = result["credentials"]
|
||||
assert ProviderName.HTTP in field_info.provider
|
||||
assert "host_scoped" in field_info.supported_types
|
||||
|
||||
def test_provider_discriminator_still_works(self):
|
||||
"""Verify provider-based discrimination (e.g. model -> provider) is preserved.
|
||||
|
||||
The refactored conditional in _resolve_discriminated_credentials split the
|
||||
original single ``if`` into nested ``if/else`` branches. This test ensures
|
||||
the provider-based path still narrows the provider correctly.
|
||||
"""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
# Should narrow provider to openai
|
||||
assert ProviderName.OPENAI in field_info.provider
|
||||
assert "gpt-4o-mini" in field_info.discriminator_values
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (host-scoped)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingHostScopedCredential:
|
||||
"""Tests for find_matching_credential with host-scoped credentials."""
|
||||
|
||||
def _make_host_scoped_cred(
|
||||
self, host: str, cred_id: str = "test-cred-id"
|
||||
) -> HostScopedCredentials:
|
||||
return HostScopedCredentials(
|
||||
id=cred_id,
|
||||
provider="http",
|
||||
host=host,
|
||||
headers={"Authorization": SecretStr("Bearer test-token")},
|
||||
title=f"Cred for {host}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, discriminator_values: set | None = None
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.HTTP]),
|
||||
credentials_types=_HOST_SCOPED_TYPES,
|
||||
credentials_scopes=None,
|
||||
discriminator="url",
|
||||
discriminator_values=discriminator_values or set(),
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_host(self):
|
||||
"""A host-scoped credential matching the URL host should be returned."""
|
||||
cred = self._make_host_scoped_cred("api.example.com")
|
||||
field_info = self._make_field_info({"https://api.example.com/v1/data"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_host(self):
|
||||
"""A host-scoped credential for a different host should not match."""
|
||||
cred = self._make_host_scoped_cred("api.github.com")
|
||||
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_any_when_no_discriminator_values(self):
|
||||
"""With empty discriminator_values, any host-scoped credential matches.
|
||||
|
||||
Note: this tests the current fallback behavior in _credential_is_for_host()
|
||||
where empty discriminator_values means "no host constraint" and any
|
||||
host-scoped credential is accepted. This is by design for the case where
|
||||
the target URL is not yet known (e.g. schema preview with empty input).
|
||||
"""
|
||||
cred = self._make_host_scoped_cred("api.anything.com")
|
||||
field_info = self._make_field_info(set())
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_wildcard_host_matching(self):
|
||||
"""Wildcard host (*.example.com) should match subdomains."""
|
||||
cred = self._make_host_scoped_cred("*.example.com")
|
||||
field_info = self._make_field_info({"https://api.example.com/v1/data"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple host-scoped credentials exist, the correct one is selected."""
|
||||
cred_github = self._make_host_scoped_cred("api.github.com", "github-cred")
|
||||
cred_stripe = self._make_host_scoped_cred("api.stripe.com", "stripe-cred")
|
||||
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
|
||||
|
||||
result = find_matching_credential([cred_github, cred_stripe], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "stripe-cred"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (api_key)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingAPIKeyCredential:
|
||||
"""Tests for find_matching_credential with API key credentials."""
|
||||
|
||||
def _make_api_key_cred(
|
||||
self, provider: str = "google_maps", cred_id: str = "test-api-key-id"
|
||||
) -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
api_key=SecretStr("sk-test-key-123"),
|
||||
title=f"API key for {provider}",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, provider: ProviderName = ProviderName.GOOGLE_MAPS
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""An API key credential matching the provider should be returned."""
|
||||
cred = self._make_api_key_cred("google_maps")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""An API key credential for a different provider should not match."""
|
||||
cred = self._make_api_key_cred("openai")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_rejects_credential_for_wrong_type(self):
|
||||
"""An OAuth2 credential should not match an api_key requirement."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-cred-id",
|
||||
provider="google_maps",
|
||||
access_token=SecretStr("mock-token"),
|
||||
scopes=[],
|
||||
title="OAuth cred (wrong type)",
|
||||
)
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([oauth_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple API key credentials exist, the correct provider is selected."""
|
||||
cred_maps = self._make_api_key_cred("google_maps", "maps-key")
|
||||
cred_openai = self._make_api_key_cred("openai", "openai-key")
|
||||
field_info = self._make_field_info(ProviderName.OPENAI)
|
||||
|
||||
result = find_matching_credential([cred_maps, cred_openai], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "openai-key"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (oauth2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingOAuth2Credential:
|
||||
"""Tests for find_matching_credential with OAuth2 credentials."""
|
||||
|
||||
def _make_oauth2_cred(
|
||||
self,
|
||||
provider: str = "google",
|
||||
scopes: list[str] | None = None,
|
||||
cred_id: str = "test-oauth2-id",
|
||||
) -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
access_token=SecretStr("mock-access-token"),
|
||||
refresh_token=SecretStr("mock-refresh-token"),
|
||||
access_token_expires_at=1234567890,
|
||||
scopes=scopes or [],
|
||||
title=f"OAuth2 cred for {provider}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self,
|
||||
provider: ProviderName = ProviderName.GOOGLE,
|
||||
required_scopes: frozenset[str] | None = None,
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_OAUTH2_TYPES,
|
||||
credentials_scopes=required_scopes,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""An OAuth2 credential matching the provider should be returned."""
|
||||
cred = self._make_oauth2_cred("google")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""An OAuth2 credential for a different provider should not match."""
|
||||
cred = self._make_oauth2_cred("github")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_credential_with_required_scopes(self):
|
||||
"""An OAuth2 credential with all required scopes should match."""
|
||||
cred = self._make_oauth2_cred(
|
||||
"google",
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
],
|
||||
)
|
||||
field_info = self._make_field_info(
|
||||
ProviderName.GOOGLE,
|
||||
required_scopes=frozenset(
|
||||
["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_rejects_credential_with_insufficient_scopes(self):
|
||||
"""An OAuth2 credential missing required scopes should not match."""
|
||||
cred = self._make_oauth2_cred(
|
||||
"google",
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
field_info = self._make_field_info(
|
||||
ProviderName.GOOGLE,
|
||||
required_scopes=frozenset(
|
||||
[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_credential_when_no_scopes_required(self):
|
||||
"""An OAuth2 credential should match when no scopes are required."""
|
||||
cred = self._make_oauth2_cred("google", scopes=[])
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple OAuth2 credentials exist, the correct one is selected."""
|
||||
cred_google = self._make_oauth2_cred("google", cred_id="google-cred")
|
||||
cred_github = self._make_oauth2_cred("github", cred_id="github-cred")
|
||||
field_info = self._make_field_info(ProviderName.GITHUB)
|
||||
|
||||
result = find_matching_credential([cred_google, cred_github], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "github-cred"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (user_password)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingUserPasswordCredential:
|
||||
"""Tests for find_matching_credential with user/password credentials."""
|
||||
|
||||
def _make_user_password_cred(
|
||||
self, provider: str = "smtp", cred_id: str = "test-userpass-id"
|
||||
) -> UserPasswordCredentials:
|
||||
return UserPasswordCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
username=SecretStr("test-user"),
|
||||
password=SecretStr("test-pass"),
|
||||
title=f"Credentials for {provider}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, provider: ProviderName = ProviderName.SMTP
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_USER_PASSWORD_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""A user/password credential matching the provider should be returned."""
|
||||
cred = self._make_user_password_cred("smtp")
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""A user/password credential for a different provider should not match."""
|
||||
cred = self._make_user_password_cred("smtp")
|
||||
field_info = self._make_field_info(ProviderName.HUBSPOT)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_rejects_credential_for_wrong_type(self):
|
||||
"""An API key credential should not match a user_password requirement."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="api-key-cred-id",
|
||||
provider="smtp",
|
||||
api_key=SecretStr("wrong-type-key"),
|
||||
title="API key cred (wrong type)",
|
||||
)
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([api_key_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple user/password credentials exist, the correct one is selected."""
|
||||
cred_smtp = self._make_user_password_cred("smtp", "smtp-cred")
|
||||
cred_hubspot = self._make_user_password_cred("hubspot", "hubspot-cred")
|
||||
field_info = self._make_field_info(ProviderName.HUBSPOT)
|
||||
|
||||
result = find_matching_credential([cred_smtp, cred_hubspot], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "hubspot-cred"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (mixed credential types)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingCredentialMixedTypes:
|
||||
"""Tests that find_matching_credential correctly filters by type in a mixed list."""
|
||||
|
||||
def test_selects_api_key_from_mixed_list(self):
|
||||
"""API key requirement should skip OAuth2 and user_password credentials."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="openai",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="openai",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="openai",
|
||||
api_key=SecretStr("sk-key"),
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.OPENAI]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[oauth_cred, userpass_cred, api_key_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "apikey-id"
|
||||
|
||||
def test_selects_oauth2_from_mixed_list(self):
|
||||
"""OAuth2 requirement should skip API key and user_password credentials."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="google",
|
||||
api_key=SecretStr("key"),
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="google",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.GOOGLE]),
|
||||
credentials_types=_OAUTH2_TYPES,
|
||||
credentials_scopes=frozenset(
|
||||
["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[api_key_cred, userpass_cred, oauth_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "oauth-id"
|
||||
|
||||
def test_selects_user_password_from_mixed_list(self):
|
||||
"""User/password requirement should skip API key and OAuth2 credentials."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="smtp",
|
||||
api_key=SecretStr("key"),
|
||||
)
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="smtp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="smtp",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.SMTP]),
|
||||
credentials_types=_USER_PASSWORD_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[api_key_cred, oauth_cred, userpass_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "userpass-id"
|
||||
|
||||
def test_returns_none_when_only_wrong_types_available(self):
|
||||
"""Should return None when all available creds have the wrong type."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="google_maps",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.GOOGLE_MAPS]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential([oauth_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_block_credentials tests (integration — all credential types)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveBlockCredentials:
|
||||
"""Integration tests for resolve_block_credentials across credential types."""
|
||||
|
||||
async def test_matches_host_scoped_credential_for_url(self):
|
||||
"""resolve_block_credentials should match a host-scoped cred for the given URL."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
mock_cred = HostScopedCredentials(
|
||||
id="matching-cred-id",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="Example API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "matching-cred-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_when_no_matching_host(self):
|
||||
"""resolve_block_credentials should report missing creds when host doesn't match."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.stripe.com/v1/charges"}
|
||||
|
||||
wrong_host_cred = HostScopedCredentials(
|
||||
id="wrong-cred-id",
|
||||
provider="http",
|
||||
host="api.github.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="GitHub API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_host_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_reports_missing_when_no_credentials(self):
|
||||
"""resolve_block_credentials should report missing when user has no creds at all."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_api_key_credential_for_llm_block(self):
|
||||
"""resolve_block_credentials should match an API key cred for an LLM block."""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
mock_cred = APIKeyCredentials(
|
||||
id="openai-key-id",
|
||||
provider="openai",
|
||||
api_key=SecretStr("sk-test-key"),
|
||||
title="OpenAI API Key",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "openai-key-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_api_key_for_wrong_provider(self):
|
||||
"""resolve_block_credentials should report missing when API key provider mismatches."""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
wrong_provider_cred = APIKeyCredentials(
|
||||
id="wrong-key-id",
|
||||
provider="google_maps",
|
||||
api_key=SecretStr("sk-wrong"),
|
||||
title="Google Maps Key",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_provider_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_oauth2_credential_for_google_block(self):
|
||||
"""resolve_block_credentials should match an OAuth2 cred for a Google block."""
|
||||
from backend.blocks.google.gmail import GmailReadBlock
|
||||
|
||||
block = GmailReadBlock()
|
||||
input_data = {}
|
||||
|
||||
mock_cred = OAuth2Credentials(
|
||||
id="google-oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("mock-token"),
|
||||
refresh_token=SecretStr("mock-refresh"),
|
||||
access_token_expires_at=9999999999,
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
title="Google OAuth",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "google-oauth-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_oauth2_with_insufficient_scopes(self):
|
||||
"""resolve_block_credentials should report missing when OAuth2 scopes are insufficient."""
|
||||
from backend.blocks.google.gmail import GmailSendBlock
|
||||
|
||||
block = GmailSendBlock()
|
||||
input_data = {}
|
||||
|
||||
# GmailSendBlock requires gmail.send scope; provide only readonly
|
||||
insufficient_cred = OAuth2Credentials(
|
||||
id="limited-oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("mock-token"),
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
title="Google OAuth (limited)",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[insufficient_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_user_password_credential_for_email_block(self):
|
||||
"""resolve_block_credentials should match a user/password cred for an SMTP block."""
|
||||
from backend.blocks.email_block import SendEmailBlock
|
||||
|
||||
block = SendEmailBlock()
|
||||
input_data = {}
|
||||
|
||||
mock_cred = UserPasswordCredentials(
|
||||
id="smtp-cred-id",
|
||||
provider="smtp",
|
||||
username=SecretStr("test-user"),
|
||||
password=SecretStr("test-pass"),
|
||||
title="SMTP Credentials",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "smtp-cred-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_user_password_for_wrong_provider(self):
|
||||
"""resolve_block_credentials should report missing when user/password provider mismatches."""
|
||||
from backend.blocks.email_block import SendEmailBlock
|
||||
|
||||
block = SendEmailBlock()
|
||||
input_data = {}
|
||||
|
||||
wrong_cred = UserPasswordCredentials(
|
||||
id="wrong-cred-id",
|
||||
provider="dataforseo",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
title="DataForSEO Creds",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RunBlockTool integration tests for authenticated HTTP
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunBlockToolAuthenticatedHttp:
|
||||
"""End-to-end tests for RunBlockTool with SendAuthenticatedWebRequestBlock."""
|
||||
|
||||
async def test_returns_setup_requirements_when_creds_missing(self):
|
||||
"""When no matching host-scoped credential exists, return SetupRequirementsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id=block.id,
|
||||
input_data={"url": "https://api.example.com/data", "method": "GET"},
|
||||
)
|
||||
|
||||
assert isinstance(response, SetupRequirementsResponse)
|
||||
assert "credentials" in response.message.lower()
|
||||
|
||||
async def test_returns_details_when_creds_matched_but_missing_required_inputs(self):
|
||||
"""When creds present + required inputs missing -> BlockDetailsResponse.
|
||||
|
||||
Note: with input_data={}, no URL is provided so discriminator_values is
|
||||
empty, meaning _credential_is_for_host() matches any host-scoped
|
||||
credential vacuously. This test exercises the "creds present + inputs
|
||||
missing" branch, not host-based matching (which is covered by
|
||||
TestFindMatchingHostScopedCredential and TestResolveBlockCredentials).
|
||||
"""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
|
||||
mock_cred = HostScopedCredentials(
|
||||
id="matching-cred-id",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="Example API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
# Call with empty input to get schema
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id=block.id,
|
||||
input_data={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockDetailsResponse)
|
||||
assert response.block.name == block.name
|
||||
# The matched credential should be included in the details
|
||||
assert len(response.block.credentials) > 0
|
||||
assert response.block.credentials[0].id == "matching-cred-id"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user