mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
124 Commits
fix/copilo
...
codex/plat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5cc72e7608 | ||
|
|
fa0650214d | ||
|
|
3e016508d4 | ||
|
|
b80d7abda9 | ||
|
|
0e310c788a | ||
|
|
91af007c18 | ||
|
|
e7ca81ed89 | ||
|
|
5164fa878f | ||
|
|
cf605ef5a3 | ||
|
|
e7bd05c6f1 | ||
|
|
22fb3549e3 | ||
|
|
1c3fe1444e | ||
|
|
b89321a688 | ||
|
|
67bdef13e7 | ||
|
|
e67dd93ee8 | ||
|
|
3140a60816 | ||
|
|
41c2ee9f83 | ||
|
|
630d6d4705 | ||
|
|
7c685c6677 | ||
|
|
ca748ee12a | ||
|
|
bbdf13c7a8 | ||
|
|
e1ea4cf326 | ||
|
|
db6b4444e0 | ||
|
|
9b1175473b | ||
|
|
752a238166 | ||
|
|
2a73d1baa9 | ||
|
|
254e6057f4 | ||
|
|
a616e5a060 | ||
|
|
c9461836c6 | ||
|
|
50a8df3d67 | ||
|
|
243b12778f | ||
|
|
3f7a8dc44d | ||
|
|
1c15d6a6cc | ||
|
|
a31be77408 | ||
|
|
1d45f2f18c | ||
|
|
27e34e9514 | ||
|
|
16d696edcc | ||
|
|
f87bbd5966 | ||
|
|
b64d1ed9fa | ||
|
|
43c81910ae | ||
|
|
3895d95826 | ||
|
|
181208528f | ||
|
|
0365a26c85 | ||
|
|
fb63ae54f0 | ||
|
|
6de79fb73f | ||
|
|
a11199aa67 | ||
|
|
5f82a71d5f | ||
|
|
d57da6c078 | ||
|
|
689cd67a13 | ||
|
|
dca89d1586 | ||
|
|
2f63fcd383 | ||
|
|
f04cd08e40 | ||
|
|
44714f1b25 | ||
|
|
78b95f8a76 | ||
|
|
6f0c1dfa11 | ||
|
|
5e595231da | ||
|
|
7b36bed8a5 | ||
|
|
372900c141 | ||
|
|
1a305db162 | ||
|
|
7afd2b249d | ||
|
|
8d22653810 | ||
|
|
48a653dc63 | ||
|
|
f6ddcbc6cb | ||
|
|
b00e16b438 | ||
|
|
b5acfb7855 | ||
|
|
1ee0bd6619 | ||
|
|
98f13a6e5d | ||
|
|
613978a611 | ||
|
|
2b0e8a5a9f | ||
|
|
08bb05141c | ||
|
|
4190f75b0b | ||
|
|
71315aa982 | ||
|
|
3ccaa5e103 | ||
|
|
960f893295 | ||
|
|
09e42041ce | ||
|
|
759effab60 | ||
|
|
a50e95f210 | ||
|
|
92b395d82a | ||
|
|
45b6ada739 | ||
|
|
86abfbd394 | ||
|
|
a7f4093424 | ||
|
|
e33b1e2105 | ||
|
|
fff101e037 | ||
|
|
da544d3411 | ||
|
|
54e5059d7c | ||
|
|
1d7d2f77f3 | ||
|
|
567bc73ec4 | ||
|
|
61ef54af05 | ||
|
|
405403e6b7 | ||
|
|
f1ac05b2e0 | ||
|
|
f115607779 | ||
|
|
1aef8b7155 | ||
|
|
ab16e63b0a | ||
|
|
45d3193727 | ||
|
|
9a08011d7d | ||
|
|
6fa66ac7da | ||
|
|
4bad08394c | ||
|
|
993c43b623 | ||
|
|
a8a62eeefc | ||
|
|
173614bcc5 | ||
|
|
fbe634fb19 | ||
|
|
a338c72c42 | ||
|
|
7f4398efa3 | ||
|
|
c2a054c511 | ||
|
|
83b00f4789 | ||
|
|
0da949ba42 | ||
|
|
95524e94b3 | ||
|
|
2c517ff9a1 | ||
|
|
7020ae2189 | ||
|
|
b9336984be | ||
|
|
9924dedddc | ||
|
|
c054799b4f | ||
|
|
f3b5d584a3 | ||
|
|
476d9dcf80 | ||
|
|
072b623f8b | ||
|
|
26b0c95936 | ||
|
|
308357de84 | ||
|
|
1a6c50c6cc | ||
|
|
6b031085bd | ||
|
|
11b846dd49 | ||
|
|
b9e29c96bd | ||
|
|
4ac0ba570a | ||
|
|
d61a2c6cd0 | ||
|
|
1c301b4b61 |
10
.claude/settings.json
Normal file
10
.claude/settings.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allowedTools": [
|
||||
"Read", "Grep", "Glob",
|
||||
"Bash(ls:*)", "Bash(cat:*)", "Bash(grep:*)", "Bash(find:*)",
|
||||
"Bash(git status:*)", "Bash(git diff:*)", "Bash(git log:*)", "Bash(git worktree:*)",
|
||||
"Bash(tmux:*)", "Bash(sleep:*)", "Bash(branchlet:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -95,6 +95,28 @@ Address comments **one at a time**: fix → commit → push → inline reply →
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
|
||||
## Codecov coverage
|
||||
|
||||
Codecov patch target is **80%** on changed lines. Checks are **informational** (not blocking) but should be green.
|
||||
|
||||
### Running coverage locally
|
||||
|
||||
**Backend** (from `autogpt_platform/backend/`):
|
||||
```bash
|
||||
poetry run pytest -s -vv --cov=backend --cov-branch --cov-report term-missing
|
||||
```
|
||||
|
||||
**Frontend** (from `autogpt_platform/frontend/`):
|
||||
```bash
|
||||
pnpm vitest run --coverage
|
||||
```
|
||||
|
||||
### When codecov/patch fails
|
||||
|
||||
1. Find uncovered files: `git diff --name-only $(gh pr view --json baseRefName --jq '.baseRefName')...HEAD`
|
||||
2. For each uncovered file — extract inline logic to `helpers.ts`/`helpers.py` and test those (highest ROI). Colocate tests as `*_test.py` (backend) or `__tests__/*.test.ts` (frontend).
|
||||
3. Run coverage locally to verify, commit, push.
|
||||
|
||||
## Format and commit
|
||||
|
||||
After fixing, format the changed code:
|
||||
|
||||
@@ -530,9 +530,19 @@ After showing all screenshots, output a **detailed** summary table:
|
||||
# but Homebrew bash is 5.x; Linux typically has bash 5.x). If running on Bash <4, use a
|
||||
# plain variable with a lookup function instead.
|
||||
declare -A SCREENSHOT_EXPLANATIONS=(
|
||||
["01-login-page.png"]="Shows the login page loaded successfully with SSO options visible."
|
||||
["02-builder-with-block.png"]="The builder canvas displays the newly added block connected to the trigger."
|
||||
# ... one entry per screenshot, using the same explanations you showed the user above
|
||||
# Each explanation MUST answer three things:
|
||||
# 1. FLOW: Which test scenario / user journey is this part of?
|
||||
# 2. STEPS: What exact actions were taken to reach this state?
|
||||
# 3. EVIDENCE: What does this screenshot prove (pass/fail/data)?
|
||||
#
|
||||
# Good example:
|
||||
# ["03-cost-log-after-run.png"]="Flow: LLM block cost tracking. Steps: Logged in as tester@gmail.com → ran 'Cost Test Agent' → waited for COMPLETED status. Evidence: PlatformCostLog table shows 1 new row with cost_microdollars=1234 and correct user_id."
|
||||
#
|
||||
# Bad example (too vague — never do this):
|
||||
# ["03-cost-log.png"]="Shows the cost log table."
|
||||
["01-login-page.png"]="Flow: Login flow. Steps: Opened /login. Evidence: Login page renders with email/password fields and SSO options visible."
|
||||
["02-builder-with-block.png"]="Flow: Block execution. Steps: Logged in → /build → added LLM block. Evidence: Builder canvas shows block connected to trigger, ready to run."
|
||||
# ... one entry per screenshot using the flow/steps/evidence format above
|
||||
)
|
||||
|
||||
TEST_RESULTS_TABLE="| 1 | Login flow | PASS | N/A | 01-login-before.png, 02-login-after.png |
|
||||
@@ -547,6 +557,9 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations
|
||||
|
||||
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
|
||||
|
||||
> **CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.**
|
||||
> Every screenshot MUST appear as `` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
|
||||
|
||||
```bash
|
||||
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
|
||||
REPO="Significant-Gravitas/AutoGPT"
|
||||
@@ -582,12 +595,25 @@ for img in "${SCREENSHOT_FILES[@]}"; do
|
||||
done
|
||||
TREE_JSON+=']'
|
||||
|
||||
# Step 2: Create tree, commit, and branch ref
|
||||
# Step 2: Create tree, commit (with parent), and branch ref
|
||||
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
--jq '.sha')
|
||||
|
||||
# Resolve existing branch tip as parent (avoids orphan commits on repeat runs)
|
||||
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || true)
|
||||
if [ -n "$PARENT_SHA" ]; then
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
-f "parents[]=$PARENT_SHA" \
|
||||
--jq '.sha')
|
||||
else
|
||||
# First commit on this branch — no parent
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
--jq '.sha')
|
||||
fi
|
||||
|
||||
gh api "repos/${REPO}/git/refs" \
|
||||
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-f sha="$COMMIT_SHA" 2>/dev/null \
|
||||
@@ -656,17 +682,123 @@ ${IMAGE_MARKDOWN}
|
||||
${FAILED_SECTION}
|
||||
INNEREOF
|
||||
|
||||
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
|
||||
POSTED_BODY=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE" --jq '.body')
|
||||
rm -f "$COMMENT_FILE"
|
||||
```
|
||||
|
||||
**The PR comment MUST include:**
|
||||
1. A summary table of all scenarios with PASS/FAIL and before/after API evidence
|
||||
2. Every successfully uploaded screenshot rendered inline; any failed uploads listed with manual attachment instructions
|
||||
3. A 1-2 sentence explanation below each screenshot describing what it proves
|
||||
3. A structured explanation below each screenshot covering: **Flow** (which scenario), **Steps** (exact actions taken to reach this state), **Evidence** (what this proves — pass/fail/data values). A bare "shows the page" caption is not acceptable.
|
||||
|
||||
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
|
||||
|
||||
**Verify inline rendering after posting — this is required, not optional:**
|
||||
|
||||
```bash
|
||||
# 1. Confirm the posted comment body contains inline image markdown syntax
|
||||
if ! echo "$POSTED_BODY" | grep -q '!\['; then
|
||||
echo "❌ FAIL: No inline image tags in posted comment body. Re-check IMAGE_MARKDOWN and re-post."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 2. Verify at least one raw URL actually resolves (catches wrong branch name, wrong path, etc.)
|
||||
FIRST_IMG_URL=$(echo "$POSTED_BODY" | grep -o 'https://raw.githubusercontent.com[^)]*' | head -1)
|
||||
if [ -n "$FIRST_IMG_URL" ]; then
|
||||
HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" --max-time 10 "$FIRST_IMG_URL")
|
||||
if [ "$HTTP_STATUS" = "200" ]; then
|
||||
echo "✅ Inline images confirmed and raw URL resolves (HTTP 200)"
|
||||
else
|
||||
echo "❌ FAIL: Raw image URL returned HTTP $HTTP_STATUS — images will not render inline."
|
||||
echo " URL: $FIRST_IMG_URL"
|
||||
echo " Check branch name, path, and that the push succeeded."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "⚠️ Could not extract a raw URL from the comment — verify manually."
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 8: Evaluate test completeness and post a GitHub review
|
||||
|
||||
After posting the PR comment, evaluate whether the test run actually covered everything it needed to. This is NOT a rubber-stamp — be critical. Then post a formal GitHub review so the PR author and reviewers can see the verdict.
|
||||
|
||||
### 8a. Evaluate against the test plan
|
||||
|
||||
Re-read `$RESULTS_DIR/test-plan.md` (written in Step 2) and `$RESULTS_DIR/test-report.md` (written in Step 5). For each scenario in the plan, answer:
|
||||
|
||||
> **Note:** `test-report.md` is written in Step 5. If it doesn't exist, write it before proceeding here — see the Step 5 template. Do not skip evaluation because the file is missing; create it from your notes instead.
|
||||
|
||||
| Question | Pass criteria |
|
||||
|----------|--------------|
|
||||
| Was it tested? | Explicit steps were executed, not just described |
|
||||
| Is there screenshot evidence? | At least one before/after screenshot per scenario |
|
||||
| Did the core feature work correctly? | Expected state matches actual state |
|
||||
| Were negative cases tested? | At least one failure/rejection case per feature |
|
||||
| Was DB/API state verified (not just UI)? | Raw API response or DB query confirms state change |
|
||||
|
||||
Build a verdict:
|
||||
- **APPROVE** — every scenario tested, evidence present, no bugs found or all bugs are minor/known
|
||||
- **REQUEST_CHANGES** — one or more: untested scenarios, missing evidence, bugs found, data not verified
|
||||
|
||||
### 8b. Post the GitHub review
|
||||
|
||||
```bash
|
||||
EVAL_FILE=$(mktemp)
|
||||
|
||||
# === STEP A: Write header ===
|
||||
cat > "$EVAL_FILE" << 'ENDEVAL'
|
||||
## 🧪 Test Evaluation
|
||||
|
||||
### Coverage checklist
|
||||
ENDEVAL
|
||||
|
||||
# === STEP B: Append ONE line per scenario — do this BEFORE calculating verdict ===
|
||||
# Format: "- ✅ **Scenario N – name**: <what was done and verified>"
|
||||
# or "- ❌ **Scenario N – name**: <what is missing or broken>"
|
||||
# Examples:
|
||||
# echo "- ✅ **Scenario 1 – Login flow**: tested, screenshot evidence present, auth token verified via API" >> "$EVAL_FILE"
|
||||
# echo "- ❌ **Scenario 3 – Cost logging**: NOT verified in DB — UI showed entry but raw SQL query was skipped" >> "$EVAL_FILE"
|
||||
#
|
||||
# !!! IMPORTANT: append ALL scenario lines here before proceeding to STEP C !!!
|
||||
|
||||
# === STEP C: Derive verdict from the checklist — runs AFTER all lines are appended ===
|
||||
FAIL_COUNT=$(grep -c "^- ❌" "$EVAL_FILE" || true)
|
||||
if [ "$FAIL_COUNT" -eq 0 ]; then
|
||||
VERDICT="APPROVE"
|
||||
else
|
||||
VERDICT="REQUEST_CHANGES"
|
||||
fi
|
||||
|
||||
# === STEP D: Append verdict section ===
|
||||
cat >> "$EVAL_FILE" << ENDVERDICT
|
||||
|
||||
### Verdict
|
||||
ENDVERDICT
|
||||
|
||||
if [ "$VERDICT" = "APPROVE" ]; then
|
||||
echo "✅ All scenarios covered with evidence. No blocking issues found." >> "$EVAL_FILE"
|
||||
else
|
||||
echo "❌ $FAIL_COUNT scenario(s) incomplete or have confirmed bugs. See ❌ items above." >> "$EVAL_FILE"
|
||||
echo "" >> "$EVAL_FILE"
|
||||
echo "**Required before merge:** address each ❌ item above." >> "$EVAL_FILE"
|
||||
fi
|
||||
|
||||
# === STEP E: Post the review ===
|
||||
gh api "repos/${REPO}/pulls/$PR_NUMBER/reviews" \
|
||||
--method POST \
|
||||
-f body="$(cat "$EVAL_FILE")" \
|
||||
-f event="$VERDICT"
|
||||
|
||||
rm -f "$EVAL_FILE"
|
||||
```
|
||||
|
||||
**Rules:**
|
||||
- Never auto-approve without checking every scenario in the test plan
|
||||
- `REQUEST_CHANGES` if ANY scenario is untested, lacks DB/API evidence, or has a confirmed bug
|
||||
- The evaluation body must list every scenario explicitly (✅ or ❌) — not just the failures
|
||||
- If you find new bugs during evaluation, add them to the request-changes body and (if `--fix` flag is set) fix them before posting
|
||||
|
||||
## Fix mode (--fix flag)
|
||||
|
||||
When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.
|
||||
|
||||
224
.claude/skills/write-frontend-tests/SKILL.md
Normal file
224
.claude/skills/write-frontend-tests/SKILL.md
Normal file
@@ -0,0 +1,224 @@
|
||||
---
|
||||
name: write-frontend-tests
|
||||
description: "Analyze the current branch diff against dev, plan integration tests for changed frontend pages/components, and write them. TRIGGER when user asks to write frontend tests, add test coverage, or 'write tests for my changes'."
|
||||
user-invocable: true
|
||||
args: "[base branch] — defaults to dev. Optionally pass a specific base branch to diff against."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Write Frontend Tests
|
||||
|
||||
Analyze the current branch's frontend changes, plan integration tests, and write them.
|
||||
|
||||
## References
|
||||
|
||||
Before writing any tests, read the testing rules and conventions:
|
||||
|
||||
- `autogpt_platform/frontend/TESTING.md` — testing strategy, file locations, examples
|
||||
- `autogpt_platform/frontend/src/tests/AGENTS.md` — detailed testing rules, MSW patterns, decision flowchart
|
||||
- `autogpt_platform/frontend/src/tests/integrations/test-utils.tsx` — custom render with providers
|
||||
- `autogpt_platform/frontend/src/tests/integrations/vitest.setup.tsx` — MSW server setup
|
||||
|
||||
## Step 1: Identify changed frontend files
|
||||
|
||||
```bash
|
||||
BASE_BRANCH="${ARGUMENTS:-dev}"
|
||||
cd autogpt_platform/frontend
|
||||
|
||||
# Get changed frontend files (excluding generated, config, and test files)
|
||||
git diff "$BASE_BRANCH"...HEAD --name-only -- src/ \
|
||||
| grep -v '__generated__' \
|
||||
| grep -v '__tests__' \
|
||||
| grep -v '\.test\.' \
|
||||
| grep -v '\.stories\.' \
|
||||
| grep -v '\.spec\.'
|
||||
```
|
||||
|
||||
Also read the diff to understand what changed:
|
||||
|
||||
```bash
|
||||
git diff "$BASE_BRANCH"...HEAD --stat -- src/
|
||||
git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
|
||||
```
|
||||
|
||||
## Step 2: Categorize changes and find test targets
|
||||
|
||||
For each changed file, determine:
|
||||
|
||||
1. **Is it a page?** (`page.tsx`) — these are the primary test targets
|
||||
2. **Is it a hook?** (`use*.ts`) — test via the page that uses it
|
||||
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
|
||||
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
|
||||
|
||||
**Priority order:**
|
||||
1. Pages with new/changed data fetching or user interactions
|
||||
2. Components with complex internal logic (modals, forms, wizards)
|
||||
3. Hooks with non-trivial business logic
|
||||
4. Pure helper functions
|
||||
|
||||
Skip: styling-only changes, type-only changes, config changes.
|
||||
|
||||
## Step 3: Check for existing tests
|
||||
|
||||
For each test target, check if tests already exist:
|
||||
|
||||
```bash
|
||||
# For a page at src/app/(platform)/library/page.tsx
|
||||
ls src/app/\(platform\)/library/__tests__/ 2>/dev/null
|
||||
|
||||
# For a component at src/app/(platform)/library/components/AgentCard/AgentCard.tsx
|
||||
ls src/app/\(platform\)/library/components/AgentCard/__tests__/ 2>/dev/null
|
||||
```
|
||||
|
||||
Note which targets have no tests (need new files) vs which have tests that need updating.
|
||||
|
||||
## Step 4: Identify API endpoints used
|
||||
|
||||
For each test target, find which API hooks are used:
|
||||
|
||||
```bash
|
||||
# Find generated API hook imports in the changed files
|
||||
grep -rn 'from.*__generated__/endpoints' src/app/\(platform\)/library/
|
||||
grep -rn 'use[A-Z].*V[12]' src/app/\(platform\)/library/
|
||||
```
|
||||
|
||||
For each API hook found, locate the corresponding MSW handler:
|
||||
|
||||
```bash
|
||||
# If the page uses useGetV2ListLibraryAgents, find its MSW handlers
|
||||
grep -rn 'getGetV2ListLibraryAgents.*Handler' src/app/api/__generated__/endpoints/library/library.msw.ts
|
||||
```
|
||||
|
||||
List every MSW handler you will need (200 for happy path, 4xx for error paths).
|
||||
|
||||
## Step 5: Write the test plan
|
||||
|
||||
Before writing code, output a plan as a numbered list:
|
||||
|
||||
```
|
||||
Test plan for [branch name]:
|
||||
|
||||
1. src/app/(platform)/library/__tests__/main.test.tsx (NEW)
|
||||
- Renders page with agent list (MSW 200)
|
||||
- Shows loading state
|
||||
- Shows error state (MSW 422)
|
||||
- Handles empty agent list
|
||||
|
||||
2. src/app/(platform)/library/__tests__/search.test.tsx (NEW)
|
||||
- Filters agents by search query
|
||||
- Shows no results message
|
||||
- Clears search
|
||||
|
||||
3. src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx (UPDATE)
|
||||
- Add test for new "duplicate" action
|
||||
```
|
||||
|
||||
Present this plan to the user. Wait for confirmation before proceeding. If the user has feedback, adjust the plan.
|
||||
|
||||
## Step 6: Write the tests
|
||||
|
||||
For each test file in the plan, follow these conventions:
|
||||
|
||||
### File structure
|
||||
|
||||
```tsx
|
||||
import { render, screen, waitFor } from "@/tests/integrations/test-utils";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
// Import MSW handlers for endpoints the page uses
|
||||
import {
|
||||
getGetV2ListLibraryAgentsMockHandler200,
|
||||
getGetV2ListLibraryAgentsMockHandler422,
|
||||
} from "@/app/api/__generated__/endpoints/library/library.msw";
|
||||
// Import the component under test
|
||||
import LibraryPage from "../page";
|
||||
|
||||
describe("LibraryPage", () => {
|
||||
test("renders agent list from API", async () => {
|
||||
server.use(getGetV2ListLibraryAgentsMockHandler200());
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
expect(await screen.findByText(/my agents/i)).toBeDefined();
|
||||
});
|
||||
|
||||
test("shows error state on API failure", async () => {
|
||||
server.use(getGetV2ListLibraryAgentsMockHandler422());
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
expect(await screen.findByText(/error/i)).toBeDefined();
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Rules
|
||||
|
||||
- Use `render()` from `@/tests/integrations/test-utils` (NOT from `@testing-library/react` directly)
|
||||
- Use `server.use()` to set up MSW handlers BEFORE rendering
|
||||
- Use `findBy*` (async) for elements that appear after data fetching — NOT `getBy*`
|
||||
- Use `getBy*` only for elements that are immediately present in the DOM
|
||||
- Use `screen` queries — do NOT destructure from `render()`
|
||||
- Use `waitFor` when asserting side effects or state changes after interactions
|
||||
- Import `fireEvent` or `userEvent` from the test-utils for interactions
|
||||
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
|
||||
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
|
||||
- Keep tests focused: one behavior per test
|
||||
- Use descriptive test names that read like sentences
|
||||
|
||||
### Test location
|
||||
|
||||
```
|
||||
# For pages: __tests__/ next to page.tsx
|
||||
src/app/(platform)/library/__tests__/main.test.tsx
|
||||
|
||||
# For complex standalone components: __tests__/ inside component folder
|
||||
src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx
|
||||
|
||||
# For pure helpers: co-located .test.ts
|
||||
src/app/(platform)/library/helpers.test.ts
|
||||
```
|
||||
|
||||
### Custom MSW overrides
|
||||
|
||||
When the auto-generated faker data is not enough, override with specific data:
|
||||
|
||||
```tsx
|
||||
import { http, HttpResponse } from "msw";
|
||||
|
||||
server.use(
|
||||
http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
|
||||
return HttpResponse.json({
|
||||
agents: [
|
||||
{ id: "1", name: "Test Agent", description: "A test agent" },
|
||||
],
|
||||
pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
|
||||
});
|
||||
}),
|
||||
);
|
||||
```
|
||||
|
||||
Use the proxy URL pattern: `http://localhost:3000/api/proxy/api/v{version}/{path}` — this matches the MSW base URL configured in `orval.config.ts`.
|
||||
|
||||
## Step 7: Run and verify
|
||||
|
||||
After writing all tests:
|
||||
|
||||
```bash
|
||||
cd autogpt_platform/frontend
|
||||
pnpm test:unit --reporter=verbose
|
||||
```
|
||||
|
||||
If tests fail:
|
||||
1. Read the error output carefully
|
||||
2. Fix the test (not the source code, unless there is a genuine bug)
|
||||
3. Re-run until all pass
|
||||
|
||||
Then run the full checks:
|
||||
|
||||
```bash
|
||||
pnpm format
|
||||
pnpm lint
|
||||
pnpm types
|
||||
```
|
||||
78
.github/workflows/classic-autogpt-ci.yml
vendored
78
.github/workflows/classic-autogpt-ci.yml
vendored
@@ -6,11 +6,19 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -19,47 +27,22 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/original_autogpt
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
- name: Start MinIO service
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -71,41 +54,23 @@ jobs:
|
||||
git config --global user.name "Auto-GPT-Bot"
|
||||
git config --global user.email "github-bot@agpt.co"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: "3.12"
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
@@ -116,12 +81,13 @@ jobs:
|
||||
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--numprocesses=logical --durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests/unit tests/integration
|
||||
original_autogpt/tests/unit original_autogpt/tests/integration
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -135,11 +101,11 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: autogpt-agent,${{ runner.os }}
|
||||
flags: autogpt-agent
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/original_autogpt/logs/
|
||||
path: classic/logs/
|
||||
|
||||
@@ -148,7 +148,7 @@ jobs:
|
||||
--entrypoint poetry ${{ env.IMAGE_NAME }} run \
|
||||
pytest -v --cov=autogpt --cov-branch --cov-report term-missing \
|
||||
--numprocesses=4 --durations=10 \
|
||||
tests/unit tests/integration 2>&1 | tee test_output.txt
|
||||
original_autogpt/tests/unit original_autogpt/tests/integration 2>&1 | tee test_output.txt
|
||||
|
||||
test_failure=${PIPESTATUS[0]}
|
||||
|
||||
|
||||
44
.github/workflows/classic-autogpts-ci.yml
vendored
44
.github/workflows/classic-autogpts-ci.yml
vendored
@@ -10,10 +10,9 @@ on:
|
||||
- '.github/workflows/classic-autogpts-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '!**/*.md'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
@@ -21,10 +20,9 @@ on:
|
||||
- '.github/workflows/classic-autogpts-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '!**/*.md'
|
||||
|
||||
defaults:
|
||||
@@ -35,13 +33,9 @@ defaults:
|
||||
jobs:
|
||||
serve-agent-protocol:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [ original_autogpt ]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
min-python-version: '3.12'
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -55,22 +49,22 @@ jobs:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
working-directory: ./classic/${{ matrix.agent-name }}/
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
|
||||
- name: Run regression tests
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run smoke tests with direct-benchmark
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
|
||||
poetry run agbenchmark --test=WriteFile
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
AGENT_NAME: ${{ matrix.agent-name }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
|
||||
HELICONE_CACHE_ENABLED: false
|
||||
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
|
||||
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
CI: true
|
||||
|
||||
256
.github/workflows/classic-benchmark-ci.yml
vendored
256
.github/workflows/classic-benchmark-ci.yml
vendored
@@ -1,18 +1,24 @@
|
||||
name: Classic - AGBenchmark CI
|
||||
name: Classic - Direct Benchmark CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master, dev, ci-test* ]
|
||||
paths:
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -23,95 +29,16 @@ defaults:
|
||||
shell: bash
|
||||
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
min-python-version: '3.12'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
benchmark-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/benchmark
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
poetry run pytest -vv \
|
||||
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests
|
||||
env:
|
||||
CI: true
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Upload test results to Codecov
|
||||
if: ${{ !cancelled() }} # Run even if tests fail
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: agbenchmark,${{ runner.os }}
|
||||
|
||||
self-test-with-agent:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [forge]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
working-directory: classic
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -124,53 +51,120 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run basic benchmark tests
|
||||
run: |
|
||||
echo "Testing ReadFile challenge with one_shot strategy..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--json
|
||||
|
||||
echo "Testing WriteFile challenge..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests WriteFile \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
- name: Test category filtering
|
||||
run: |
|
||||
echo "Testing coding category..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--categories coding \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
- name: Test multiple strategies
|
||||
run: |
|
||||
echo "Testing multiple strategies..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot,plan_execute \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--parallel 2 \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
# Run regression tests on maintain challenges
|
||||
regression-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run regression tests
|
||||
working-directory: classic
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
|
||||
set +e # Ignore non-zero exit codes and continue execution
|
||||
echo "Running the following command: poetry run agbenchmark --maintain --mock"
|
||||
poetry run agbenchmark --maintain --mock
|
||||
EXIT_CODE=$?
|
||||
set -e # Stop ignoring non-zero exit codes
|
||||
# Check if the exit code was 5, and if so, exit with 0 instead
|
||||
if [ $EXIT_CODE -eq 5 ]; then
|
||||
echo "regression_tests.json is empty."
|
||||
fi
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock"
|
||||
poetry run agbenchmark --mock
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=data"
|
||||
poetry run agbenchmark --mock --category=data
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
|
||||
poetry run agbenchmark --mock --category=coding
|
||||
|
||||
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
|
||||
# poetry run agbenchmark --test=WriteFile
|
||||
cd ../benchmark
|
||||
poetry install
|
||||
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
|
||||
export BUILD_SKILL_TREE=true
|
||||
|
||||
# poetry run agbenchmark --mock
|
||||
|
||||
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
|
||||
# if [ ! -z "$CHANGED" ]; then
|
||||
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
|
||||
# echo "$CHANGED"
|
||||
# exit 1
|
||||
# else
|
||||
# echo "No unstaged changes."
|
||||
# fi
|
||||
echo "Running regression tests (previously beaten challenges)..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--maintain \
|
||||
--parallel 4 \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
189
.github/workflows/classic-forge-ci.yml
vendored
189
.github/workflows/classic-forge-ci.yml
vendored
@@ -6,13 +6,15 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -21,131 +23,60 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/forge
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
- name: Start MinIO service
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Checkout cassettes
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
env:
|
||||
PR_BASE: ${{ github.event.pull_request.base.ref }}
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
cassette_base_branch="${PR_BASE}"
|
||||
cd tests/vcr_cassettes
|
||||
|
||||
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
|
||||
cassette_base_branch="master"
|
||||
fi
|
||||
|
||||
if git ls-remote --exit-code --heads origin $cassette_branch ; then
|
||||
git fetch origin $cassette_branch
|
||||
git fetch origin $cassette_base_branch
|
||||
|
||||
git checkout $cassette_branch
|
||||
|
||||
# Pick non-conflicting cassette updates from the base branch
|
||||
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
|
||||
echo "Using cassettes from mirror branch '$cassette_branch'," \
|
||||
"synced to upstream branch '$cassette_base_branch'."
|
||||
else
|
||||
git checkout -b $cassette_branch
|
||||
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
|
||||
"Using cassettes from '$cassette_base_branch'."
|
||||
fi
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Install Playwright browsers
|
||||
run: poetry run playwright install chromium
|
||||
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
poetry run pytest -vv \
|
||||
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
forge
|
||||
forge/forge forge/tests
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
# API keys - tests that need these will skip if not available
|
||||
# Secrets are not available to fork PRs (GitHub security feature)
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -159,85 +90,11 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: forge,${{ runner.os }}
|
||||
|
||||
- id: setup_git_auth
|
||||
name: Set up git token authentication
|
||||
# Cassettes may be pushed even when tests fail
|
||||
if: success() || failure()
|
||||
run: |
|
||||
config_key="http.${{ github.server_url }}/.extraheader"
|
||||
if [ "${{ runner.os }}" = 'macOS' ]; then
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
|
||||
else
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
|
||||
fi
|
||||
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
echo "config_key=$config_key" >> $GITHUB_OUTPUT
|
||||
|
||||
- id: push_cassettes
|
||||
name: Push updated cassettes
|
||||
# For pull requests, push updated cassettes even when tests fail
|
||||
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
|
||||
env:
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
|
||||
is_pull_request=true
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
else
|
||||
cassette_branch="${{ github.ref_name }}"
|
||||
fi
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
# Commit & push changes to cassettes if any
|
||||
if ! git diff --quiet; then
|
||||
git add .
|
||||
git commit -m "Auto-update cassettes"
|
||||
git push origin HEAD:$cassette_branch
|
||||
if [ ! $is_pull_request ]; then
|
||||
cd ../..
|
||||
git add tests/vcr_cassettes
|
||||
git commit -m "Update cassette submodule"
|
||||
git push origin HEAD:$cassette_branch
|
||||
fi
|
||||
echo "updated=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "updated=false" >> $GITHUB_OUTPUT
|
||||
echo "No cassette changes to commit"
|
||||
fi
|
||||
|
||||
- name: Post Set up git token auth
|
||||
if: steps.setup_git_auth.outcome == 'success'
|
||||
run: |
|
||||
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
|
||||
- name: Apply "behaviour change" label and comment on PR
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
run: |
|
||||
PR_NUMBER="${{ github.event.pull_request.number }}"
|
||||
TOKEN="${{ secrets.PAT_REVIEW }}"
|
||||
REPO="${{ github.repository }}"
|
||||
|
||||
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
|
||||
echo "Adding label and comment..."
|
||||
echo $TOKEN | gh auth login --with-token
|
||||
gh issue edit $PR_NUMBER --add-label "behaviour change"
|
||||
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
|
||||
fi
|
||||
flags: forge
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/forge/logs/
|
||||
path: classic/logs/
|
||||
|
||||
60
.github/workflows/classic-frontend-ci.yml
vendored
60
.github/workflows/classic-frontend-ci.yml
vendored
@@ -1,60 +0,0 @@
|
||||
name: Classic - Frontend CI/CD
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- dev
|
||||
- 'ci-test*' # This will match any branch that starts with "ci-test"
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Flutter
|
||||
uses: subosito/flutter-action@v2
|
||||
with:
|
||||
flutter-version: '3.13.2'
|
||||
|
||||
- name: Build Flutter to Web
|
||||
run: |
|
||||
cd classic/frontend
|
||||
flutter build web --base-href /app/
|
||||
|
||||
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
|
||||
# if: github.event_name == 'push'
|
||||
# run: |
|
||||
# git config --local user.email "action@github.com"
|
||||
# git config --local user.name "GitHub Action"
|
||||
# git add classic/frontend/build/web
|
||||
# git checkout -B ${{ env.BUILD_BRANCH }}
|
||||
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
|
||||
# git push -f origin ${{ env.BUILD_BRANCH }}
|
||||
|
||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||
if: github.event_name == 'push'
|
||||
uses: peter-evans/create-pull-request@v8
|
||||
with:
|
||||
add-paths: classic/frontend/build/web
|
||||
base: ${{ github.ref_name }}
|
||||
branch: ${{ env.BUILD_BRANCH }}
|
||||
delete-branch: true
|
||||
title: "Update frontend build in `${{ github.ref_name }}`"
|
||||
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
|
||||
commit-message: "Update frontend build based on commit ${{ github.sha }}"
|
||||
67
.github/workflows/classic-python-checks.yml
vendored
67
.github/workflows/classic-python-checks.yml
vendored
@@ -7,7 +7,9 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
pull_request:
|
||||
@@ -16,7 +18,9 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
|
||||
@@ -27,44 +31,13 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
get-changed-parts:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- id: changes-in
|
||||
name: Determine affected subprojects
|
||||
uses: dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
original_autogpt:
|
||||
- classic/original_autogpt/autogpt/**
|
||||
- classic/original_autogpt/tests/**
|
||||
- classic/original_autogpt/poetry.lock
|
||||
forge:
|
||||
- classic/forge/forge/**
|
||||
- classic/forge/tests/**
|
||||
- classic/forge/poetry.lock
|
||||
benchmark:
|
||||
- classic/benchmark/agbenchmark/**
|
||||
- classic/benchmark/tests/**
|
||||
- classic/benchmark/poetry.lock
|
||||
outputs:
|
||||
changed-parts: ${{ steps.changes-in.outputs.changes }}
|
||||
|
||||
lint:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
min-python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -81,42 +54,31 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
run: poetry install
|
||||
|
||||
# Lint
|
||||
|
||||
- name: Lint (isort)
|
||||
run: poetry run isort --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Black)
|
||||
if: success() || failure()
|
||||
run: poetry run black --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Flake8)
|
||||
if: success() || failure()
|
||||
run: poetry run flake8 .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
types:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
min-python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -133,19 +95,16 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
run: poetry install
|
||||
|
||||
# Typecheck
|
||||
|
||||
- name: Typecheck
|
||||
if: success() || failure()
|
||||
run: poetry run pyright
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
20
.github/workflows/platform-backend-ci.yml
vendored
20
.github/workflows/platform-backend-ci.yml
vendored
@@ -269,12 +269,14 @@ jobs:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- name: Run pytest
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG \
|
||||
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
|
||||
else
|
||||
poetry run pytest -s -vv
|
||||
poetry run pytest -s -vv \
|
||||
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
|
||||
fi
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
@@ -287,11 +289,13 @@ jobs:
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-backend
|
||||
files: ./autogpt_platform/backend/coverage.xml
|
||||
|
||||
env:
|
||||
CI: true
|
||||
|
||||
8
.github/workflows/platform-frontend-ci.yml
vendored
8
.github/workflows/platform-frontend-ci.yml
vendored
@@ -148,3 +148,11 @@ jobs:
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: pnpm test:unit
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-frontend
|
||||
files: ./autogpt_platform/frontend/coverage/cobertura-coverage.xml
|
||||
|
||||
25
.github/workflows/platform-fullstack-ci.yml
vendored
25
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -179,21 +179,30 @@ jobs:
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
export NEXT_PUBLIC_SOURCEMAPS NEXT_PUBLIC_PW_TEST
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Ensure NEXT_PUBLIC_SOURCEMAPS is in resolved compose
|
||||
# (docker compose config on some versions drops this arg)
|
||||
if ! grep -q "NEXT_PUBLIC_SOURCEMAPS" docker-compose.resolved.yml; then
|
||||
echo "Injecting NEXT_PUBLIC_SOURCEMAPS into resolved compose (docker compose config dropped it)"
|
||||
sed -i '/NEXT_PUBLIC_PW_TEST/a\ NEXT_PUBLIC_SOURCEMAPS: "true"' docker-compose.resolved.yml
|
||||
fi
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}-sourcemaps" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
NEXT_PUBLIC_SOURCEMAPS: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
@@ -279,6 +288,11 @@ jobs:
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Copy source maps from Docker for E2E coverage
|
||||
run: |
|
||||
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
|
||||
docker cp "$FRONTEND_CONTAINER":/app/.next/static .next-static-coverage
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
@@ -289,6 +303,15 @@ jobs:
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload E2E coverage to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-frontend-e2e
|
||||
files: ./autogpt_platform/frontend/coverage/e2e/cobertura-coverage.xml
|
||||
disable_search: true
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -3,6 +3,7 @@
|
||||
classic/original_autogpt/keys.py
|
||||
classic/original_autogpt/*.json
|
||||
auto_gpt_workspace/*
|
||||
.autogpt/
|
||||
*.mpeg
|
||||
.env
|
||||
# Root .env files
|
||||
@@ -16,6 +17,7 @@ log-ingestion.txt
|
||||
/logs
|
||||
*.log
|
||||
*.mp3
|
||||
!autogpt_platform/frontend/public/notification.mp3
|
||||
mem.sqlite3
|
||||
venvAutoGPT
|
||||
|
||||
@@ -159,6 +161,10 @@ CURRENT_BULLETIN.md
|
||||
|
||||
# AgBenchmark
|
||||
classic/benchmark/agbenchmark/reports/
|
||||
classic/reports/
|
||||
classic/direct_benchmark/reports/
|
||||
classic/.benchmark_workspaces/
|
||||
classic/direct_benchmark/.benchmark_workspaces/
|
||||
|
||||
# Nodejs
|
||||
package-lock.json
|
||||
@@ -177,9 +183,13 @@ autogpt_platform/backend/settings.py
|
||||
|
||||
*.ign.*
|
||||
.test-contents
|
||||
**/.claude/settings.local.json
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
|
||||
# Test database
|
||||
test.db
|
||||
.next
|
||||
# Implementation plans (generated by AI agents)
|
||||
plans/
|
||||
|
||||
36
.gitleaks.toml
Normal file
36
.gitleaks.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
title = "AutoGPT Gitleaks Config"
|
||||
|
||||
[extend]
|
||||
useDefault = true
|
||||
|
||||
[allowlist]
|
||||
description = "Global allowlist"
|
||||
paths = [
|
||||
# Template/example env files (no real secrets)
|
||||
'''\.env\.(default|example|template)$''',
|
||||
# Lock files
|
||||
'''pnpm-lock\.yaml$''',
|
||||
'''poetry\.lock$''',
|
||||
# Secrets baseline
|
||||
'''\.secrets\.baseline$''',
|
||||
# Build artifacts and caches (should not be committed)
|
||||
'''__pycache__/''',
|
||||
'''classic/frontend/build/''',
|
||||
# Docker dev setup (local dev JWTs/keys only)
|
||||
'''autogpt_platform/db/docker/''',
|
||||
# Load test configs (dev JWTs)
|
||||
'''load-tests/configs/''',
|
||||
# Test files with fake/fixture keys (_test.py, test_*.py, conftest.py)
|
||||
'''(_test|test_.*|conftest)\.py$''',
|
||||
# Documentation (only contains placeholder keys in curl/API examples)
|
||||
'''docs/.*\.md$''',
|
||||
# Firebase config (public API keys by design)
|
||||
'''google-services\.json$''',
|
||||
'''classic/frontend/(lib|web)/''',
|
||||
]
|
||||
# CI test-only encryption key (marked DO NOT USE IN PRODUCTION)
|
||||
regexes = [
|
||||
'''dvziYgz0KSK8FENhju0ZYi8''',
|
||||
# LLM model name enum values falsely flagged as API keys
|
||||
'''Llama-\d.*Instruct''',
|
||||
]
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +0,0 @@
|
||||
[submodule "classic/forge/tests/vcr_cassettes"]
|
||||
path = classic/forge/tests/vcr_cassettes
|
||||
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
|
||||
@@ -23,9 +23,15 @@ repos:
|
||||
- id: detect-secrets
|
||||
name: Detect secrets
|
||||
description: Detects high entropy strings that are likely to be passwords.
|
||||
args: ["--baseline", ".secrets.baseline"]
|
||||
files: ^autogpt_platform/
|
||||
exclude: pnpm-lock\.yaml$
|
||||
stages: [pre-push]
|
||||
exclude: (pnpm-lock\.yaml|\.env\.(default|example|template))$
|
||||
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.24.3
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
name: Detect secrets (gitleaks)
|
||||
|
||||
- repo: local
|
||||
# For proper type checking, all dependencies need to be up-to-date.
|
||||
@@ -84,51 +90,16 @@ repos:
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - AutoGPT
|
||||
alias: poetry-install-classic-autogpt
|
||||
name: Check & Install dependencies - Classic
|
||||
alias: poetry-install-classic
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/original_autogpt install
|
||||
'
|
||||
# include forge source (since it's a path dependency)
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Forge
|
||||
alias: poetry-install-classic-forge
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/forge install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Benchmark
|
||||
alias: poetry-install-classic-benchmark
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/benchmark install
|
||||
fi | grep -qE "^classic/poetry\.lock$" || exit 0;
|
||||
poetry -C classic install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
@@ -223,26 +194,10 @@ repos:
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - AutoGPT
|
||||
alias: isort-classic-autogpt
|
||||
entry: poetry -P classic/original_autogpt run isort -p autogpt
|
||||
files: ^classic/original_autogpt/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Forge
|
||||
alias: isort-classic-forge
|
||||
entry: poetry -P classic/forge run isort -p forge
|
||||
files: ^classic/forge/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Benchmark
|
||||
alias: isort-classic-benchmark
|
||||
entry: poetry -P classic/benchmark run isort -p agbenchmark
|
||||
files: ^classic/benchmark/
|
||||
name: Lint (isort) - Classic
|
||||
alias: isort-classic
|
||||
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
@@ -256,26 +211,13 @@ repos:
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.0.0
|
||||
# To have flake8 load the config of the individual subprojects, we have to call
|
||||
# them separately.
|
||||
# Use consolidated flake8 config at classic/.flake8
|
||||
hooks:
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - AutoGPT
|
||||
alias: flake8-classic-autogpt
|
||||
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
|
||||
args: [--config=classic/original_autogpt/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Forge
|
||||
alias: flake8-classic-forge
|
||||
files: ^classic/forge/(forge|tests)/
|
||||
args: [--config=classic/forge/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Benchmark
|
||||
alias: flake8-classic-benchmark
|
||||
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
|
||||
args: [--config=classic/benchmark/.flake8]
|
||||
name: Lint (Flake8) - Classic
|
||||
alias: flake8-classic
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
args: [--config=classic/.flake8]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
@@ -311,29 +253,10 @@ repos:
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - AutoGPT
|
||||
alias: pyright-classic-autogpt
|
||||
entry: poetry -C classic/original_autogpt run pyright
|
||||
# include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Forge
|
||||
alias: pyright-classic-forge
|
||||
entry: poetry -C classic/forge run pyright
|
||||
files: ^classic/forge/(forge/|poetry\.lock$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Benchmark
|
||||
alias: pyright-classic-benchmark
|
||||
entry: poetry -C classic/benchmark run pyright
|
||||
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
name: Typecheck - Classic
|
||||
alias: pyright-classic
|
||||
entry: poetry -C classic run pyright
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
@@ -360,26 +283,9 @@ repos:
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - AutoGPT (excl. slow tests)
|
||||
# alias: pytest-classic-autogpt
|
||||
# entry: bash -c 'cd classic/original_autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
|
||||
# # include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
# files: ^(classic/original_autogpt/((autogpt|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - Forge (excl. slow tests)
|
||||
# alias: pytest-classic-forge
|
||||
# entry: bash -c 'cd classic/forge && poetry run pytest --cov=forge -m "not slow"'
|
||||
# files: ^classic/forge/(forge/|tests/|poetry\.lock$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - Benchmark
|
||||
# alias: pytest-classic-benchmark
|
||||
# entry: bash -c 'cd classic/benchmark && poetry run pytest --cov=benchmark'
|
||||
# files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
# name: Run tests - Classic (excl. slow tests)
|
||||
# alias: pytest-classic
|
||||
# entry: bash -c 'cd classic && poetry run pytest -m "not slow"'
|
||||
# files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
467
.secrets.baseline
Normal file
467
.secrets.baseline
Normal file
@@ -0,0 +1,467 @@
|
||||
{
|
||||
"version": "1.5.0",
|
||||
"plugins_used": [
|
||||
{
|
||||
"name": "ArtifactoryDetector"
|
||||
},
|
||||
{
|
||||
"name": "AWSKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "AzureStorageKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "Base64HighEntropyString",
|
||||
"limit": 4.5
|
||||
},
|
||||
{
|
||||
"name": "BasicAuthDetector"
|
||||
},
|
||||
{
|
||||
"name": "CloudantDetector"
|
||||
},
|
||||
{
|
||||
"name": "DiscordBotTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "GitHubTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "GitLabTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "HexHighEntropyString",
|
||||
"limit": 3.0
|
||||
},
|
||||
{
|
||||
"name": "IbmCloudIamDetector"
|
||||
},
|
||||
{
|
||||
"name": "IbmCosHmacDetector"
|
||||
},
|
||||
{
|
||||
"name": "IPPublicDetector"
|
||||
},
|
||||
{
|
||||
"name": "JwtTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "KeywordDetector",
|
||||
"keyword_exclude": ""
|
||||
},
|
||||
{
|
||||
"name": "MailchimpDetector"
|
||||
},
|
||||
{
|
||||
"name": "NpmDetector"
|
||||
},
|
||||
{
|
||||
"name": "OpenAIDetector"
|
||||
},
|
||||
{
|
||||
"name": "PrivateKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "PypiTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "SendGridDetector"
|
||||
},
|
||||
{
|
||||
"name": "SlackDetector"
|
||||
},
|
||||
{
|
||||
"name": "SoftlayerDetector"
|
||||
},
|
||||
{
|
||||
"name": "SquareOAuthDetector"
|
||||
},
|
||||
{
|
||||
"name": "StripeDetector"
|
||||
},
|
||||
{
|
||||
"name": "TelegramBotTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "TwilioKeyDetector"
|
||||
}
|
||||
],
|
||||
"filters_used": [
|
||||
{
|
||||
"path": "detect_secrets.filters.allowlist.is_line_allowlisted"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
|
||||
"min_level": 2
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_indirect_reference"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_likely_id_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_lock_file"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_potential_uuid"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_sequential_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_swagger_file"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_templated_secret"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.regex.should_exclude_file",
|
||||
"pattern": [
|
||||
"\\.env$",
|
||||
"pnpm-lock\\.yaml$",
|
||||
"\\.env\\.(default|example|template)$",
|
||||
"__pycache__",
|
||||
"_test\\.py$",
|
||||
"test_.*\\.py$",
|
||||
"conftest\\.py$",
|
||||
"poetry\\.lock$",
|
||||
"node_modules"
|
||||
]
|
||||
}
|
||||
],
|
||||
"results": {
|
||||
"autogpt_platform/backend/backend/api/external/v1/integrations.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/api/external/v1/integrations.py",
|
||||
"hashed_secret": "665b1e3851eefefa3fb878654292f16597d25155",
|
||||
"is_verified": false,
|
||||
"line_number": 289
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/airtable/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/airtable/_config.py",
|
||||
"hashed_secret": "57e168b03afb7c1ee3cdc4ee3db2fe1cc6e0df26",
|
||||
"is_verified": false,
|
||||
"line_number": 29
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/dataforseo/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/dataforseo/_config.py",
|
||||
"hashed_secret": "32ce93887331fa5d192f2876ea15ec000c7d58b8",
|
||||
"is_verified": false,
|
||||
"line_number": 12
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/checks.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/checks.py",
|
||||
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
|
||||
"is_verified": false,
|
||||
"line_number": 108
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/ci.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/ci.py",
|
||||
"hashed_secret": "90bd1b48e958257948487b90bee080ba5ed00caa",
|
||||
"is_verified": false,
|
||||
"line_number": 123
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "f96896dafced7387dcd22343b8ea29d3d2c65663",
|
||||
"is_verified": false,
|
||||
"line_number": 42
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "b80a94d5e70bedf4f5f89d2f5a5255cc9492d12e",
|
||||
"is_verified": false,
|
||||
"line_number": 193
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "75b17e517fe1b3136394f6bec80c4f892da75e42",
|
||||
"is_verified": false,
|
||||
"line_number": 344
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "b0bfb5e4e2394e7f8906e5ed1dffd88b2bc89dd5",
|
||||
"is_verified": false,
|
||||
"line_number": 534
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/statuses.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/statuses.py",
|
||||
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
|
||||
"is_verified": false,
|
||||
"line_number": 85
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/google/docs.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/google/docs.py",
|
||||
"hashed_secret": "c95da0c6696342c867ef0c8258d2f74d20fd94d4",
|
||||
"is_verified": false,
|
||||
"line_number": 203
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/google/sheets.py": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/google/sheets.py",
|
||||
"hashed_secret": "bd5a04fa3667e693edc13239b6d310c5c7a8564b",
|
||||
"is_verified": false,
|
||||
"line_number": 57
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/linear/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/linear/_config.py",
|
||||
"hashed_secret": "b37f020f42d6d613b6ce30103e4d408c4499b3bb",
|
||||
"is_verified": false,
|
||||
"line_number": 53
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/medium.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/medium.py",
|
||||
"hashed_secret": "ff998abc1ce6d8f01a675fa197368e44c8916e9c",
|
||||
"is_verified": false,
|
||||
"line_number": 131
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/replicate/replicate_block.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
|
||||
"hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
|
||||
"is_verified": false,
|
||||
"line_number": 55
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/slant3d/webhook.py",
|
||||
"hashed_secret": "36263c76947443b2f6e6b78153967ac4a7da99f9",
|
||||
"is_verified": false,
|
||||
"line_number": 100
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/talking_head.py": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/talking_head.py",
|
||||
"hashed_secret": "44ce2d66222529eea4a32932823466fc0601c799",
|
||||
"is_verified": false,
|
||||
"line_number": 113
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/wordpress/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/wordpress/_config.py",
|
||||
"hashed_secret": "e62679512436161b78e8a8d68c8829c2a1031ccb",
|
||||
"is_verified": false,
|
||||
"line_number": 17
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/util/cache.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/util/cache.py",
|
||||
"hashed_secret": "37f0c918c3fa47ca4a70e42037f9f123fdfbc75b",
|
||||
"is_verified": false,
|
||||
"line_number": 449
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts",
|
||||
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
|
||||
"is_verified": false,
|
||||
"line_number": 6
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json",
|
||||
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json",
|
||||
"hashed_secret": "5a6d1c612954979ea99ee33dbb2d231b00f6ac0a",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 6
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
|
||||
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
|
||||
"is_verified": false,
|
||||
"line_number": 8
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
|
||||
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
|
||||
"is_verified": false,
|
||||
"line_number": 7
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 192
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
|
||||
"hashed_secret": "86275db852204937bbdbdebe5fabe8536e030ab6",
|
||||
"is_verified": false,
|
||||
"line_number": 193
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
|
||||
"hashed_secret": "47acd2028cf81b5da88ddeedb2aea4eca4b71fbd",
|
||||
"is_verified": false,
|
||||
"line_number": 102
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
|
||||
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
|
||||
"is_verified": false,
|
||||
"line_number": 103
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "9c486c92f1a7420e1045c7ad963fbb7ba3621025",
|
||||
"is_verified": false,
|
||||
"line_number": 73
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "9277508c7a6effc8fb59163efbfada189e35425c",
|
||||
"is_verified": false,
|
||||
"line_number": 75
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "8dc7e2cb1d0935897d541bf5facab389b8a50340",
|
||||
"is_verified": false,
|
||||
"line_number": 77
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "79a26ad48775944299be6aaf9fb1d5302c1ed75b",
|
||||
"is_verified": false,
|
||||
"line_number": 79
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "a3b62b44500a1612e48d4cab8294df81561b3b1a",
|
||||
"is_verified": false,
|
||||
"line_number": 81
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "a58979bd0b21ef4f50417d001008e60dd7a85c64",
|
||||
"is_verified": false,
|
||||
"line_number": 83
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "6cb6e075f8e8c7c850f9d128d6608e5dbe209a79",
|
||||
"is_verified": false,
|
||||
"line_number": 85
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/lib/constants.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/lib/constants.ts",
|
||||
"hashed_secret": "27b924db06a28cc755fb07c54f0fddc30659fe4d",
|
||||
"is_verified": false,
|
||||
"line_number": 10
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/tests/credentials/index.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/tests/credentials/index.ts",
|
||||
"hashed_secret": "c18006fc138809314751cd1991f1e0b820fabd37",
|
||||
"is_verified": false,
|
||||
"line_number": 4
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_at": "2026-04-02T13:10:54Z"
|
||||
}
|
||||
@@ -30,7 +30,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
5. **Testing**: Integration tests (Vitest + RTL + MSW) are the default (~90%, page-level). Playwright for E2E critical flows. Storybook for design system components. See `autogpt_platform/frontend/TESTING.md`
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
@@ -47,7 +47,9 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
## Testing
|
||||
|
||||
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
||||
- Frontend: `pnpm test` or `pnpm test-ui` for Playwright tests. See `docs/content/platform/contributing/tests.md` for tips.
|
||||
- Frontend integration tests: `pnpm test:unit` (Vitest + RTL + MSW, primary testing approach).
|
||||
- Frontend E2E tests: `pnpm test` or `pnpm test-ui` for Playwright tests.
|
||||
- See `autogpt_platform/frontend/TESTING.md` for the full testing strategy.
|
||||
|
||||
Always run the relevant linters and tests before committing.
|
||||
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from cachetools import TTLCache
|
||||
from fastapi import APIRouter, Query, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.platform_cost import (
|
||||
CostLogRow,
|
||||
PlatformCostDashboard,
|
||||
get_platform_cost_dashboard,
|
||||
get_platform_cost_logs,
|
||||
)
|
||||
from backend.util.models import Pagination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache dashboard results for 30 seconds per unique filter combination.
|
||||
# The table is append-only so stale reads are acceptable for analytics.
|
||||
_DASHBOARD_CACHE_TTL = 30
|
||||
_dashboard_cache: TTLCache[tuple, PlatformCostDashboard] = TTLCache(
|
||||
maxsize=256, ttl=_DASHBOARD_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/platform-costs",
|
||||
tags=["platform-cost", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
class PlatformCostLogsResponse(BaseModel):
|
||||
logs: list[CostLogRow]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
@router.get(
|
||||
"/dashboard",
|
||||
response_model=PlatformCostDashboard,
|
||||
summary="Get Platform Cost Dashboard",
|
||||
)
|
||||
async def get_cost_dashboard(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
start: datetime | None = Query(None),
|
||||
end: datetime | None = Query(None),
|
||||
provider: str | None = Query(None),
|
||||
user_id: str | None = Query(None),
|
||||
):
|
||||
logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
|
||||
cache_key = (start, end, provider, user_id)
|
||||
cached = _dashboard_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
result = await get_platform_cost_dashboard(
|
||||
start=start,
|
||||
end=end,
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
)
|
||||
_dashboard_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
"/logs",
|
||||
response_model=PlatformCostLogsResponse,
|
||||
summary="Get Platform Cost Logs",
|
||||
)
|
||||
async def get_cost_logs(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
start: datetime | None = Query(None),
|
||||
end: datetime | None = Query(None),
|
||||
provider: str | None = Query(None),
|
||||
user_id: str | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
logger.info("Admin %s fetching platform cost logs", admin_user_id)
|
||||
logs, total = await get_platform_cost_logs(
|
||||
start=start,
|
||||
end=end,
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
return PlatformCostLogsResponse(
|
||||
logs=logs,
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,192 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.platform_cost import PlatformCostDashboard
|
||||
|
||||
from . import platform_cost_routes
|
||||
from .platform_cost_routes import router as platform_cost_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(platform_cost_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
# Clear TTL cache so each test starts cold.
|
||||
platform_cost_routes._dashboard_cache.clear()
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_get_dashboard_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
real_dashboard = PlatformCostDashboard(
|
||||
by_provider=[],
|
||||
by_user=[],
|
||||
total_cost_microdollars=0,
|
||||
total_requests=0,
|
||||
total_users=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
|
||||
AsyncMock(return_value=real_dashboard),
|
||||
)
|
||||
|
||||
response = client.get("/platform-costs/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "by_provider" in data
|
||||
assert "by_user" in data
|
||||
assert data["total_cost_microdollars"] == 0
|
||||
|
||||
|
||||
def test_get_logs_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
|
||||
AsyncMock(return_value=([], 0)),
|
||||
)
|
||||
|
||||
response = client.get("/platform-costs/logs")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["logs"] == []
|
||||
assert data["pagination"]["total_items"] == 0
|
||||
|
||||
|
||||
def test_get_dashboard_with_filters(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
real_dashboard = PlatformCostDashboard(
|
||||
by_provider=[],
|
||||
by_user=[],
|
||||
total_cost_microdollars=0,
|
||||
total_requests=0,
|
||||
total_users=0,
|
||||
)
|
||||
mock_dashboard = AsyncMock(return_value=real_dashboard)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
|
||||
mock_dashboard,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/platform-costs/dashboard",
|
||||
params={
|
||||
"start": "2026-01-01T00:00:00",
|
||||
"end": "2026-04-01T00:00:00",
|
||||
"provider": "openai",
|
||||
"user_id": "test-user-123",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_dashboard.assert_called_once()
|
||||
call_kwargs = mock_dashboard.call_args.kwargs
|
||||
assert call_kwargs["provider"] == "openai"
|
||||
assert call_kwargs["user_id"] == "test-user-123"
|
||||
assert call_kwargs["start"] is not None
|
||||
assert call_kwargs["end"] is not None
|
||||
|
||||
|
||||
def test_get_logs_with_pagination(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
|
||||
AsyncMock(return_value=([], 0)),
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/platform-costs/logs",
|
||||
params={"page": 2, "page_size": 25, "provider": "anthropic"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["pagination"]["current_page"] == 2
|
||||
assert data["pagination"]["page_size"] == 25
|
||||
|
||||
|
||||
def test_get_dashboard_requires_admin() -> None:
|
||||
import fastapi
|
||||
from fastapi import HTTPException
|
||||
|
||||
def reject_jwt(request: fastapi.Request):
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = reject_jwt
|
||||
try:
|
||||
response = client.get("/platform-costs/dashboard")
|
||||
assert response.status_code == 401
|
||||
response = client.get("/platform-costs/logs")
|
||||
assert response.status_code == 401
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_get_dashboard_rejects_non_admin(mock_jwt_user, mock_jwt_admin) -> None:
|
||||
"""Non-admin JWT must be rejected with 403 by requires_admin_user."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
try:
|
||||
response = client.get("/platform-costs/dashboard")
|
||||
assert response.status_code == 403
|
||||
response = client.get("/platform-costs/logs")
|
||||
assert response.status_code == 403
|
||||
finally:
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
|
||||
|
||||
def test_get_logs_invalid_page_size_too_large() -> None:
|
||||
"""page_size > 200 must be rejected with 422."""
|
||||
response = client.get("/platform-costs/logs", params={"page_size": 201})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_logs_invalid_page_size_zero() -> None:
|
||||
"""page_size = 0 (below ge=1) must be rejected with 422."""
|
||||
response = client.get("/platform-costs/logs", params={"page_size": 0})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_logs_invalid_page_negative() -> None:
|
||||
"""page < 1 must be rejected with 422."""
|
||||
response = client.get("/platform-costs/logs", params={"page": 0})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_dashboard_invalid_date_format() -> None:
|
||||
"""Malformed start date must be rejected with 422."""
|
||||
response = client.get("/platform-costs/dashboard", params={"start": "not-a-date"})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_dashboard_cache_hit(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Second identical request returns cached result without calling the DB again."""
|
||||
real_dashboard = PlatformCostDashboard(
|
||||
by_provider=[],
|
||||
by_user=[],
|
||||
total_cost_microdollars=42,
|
||||
total_requests=1,
|
||||
total_users=1,
|
||||
)
|
||||
mock_fn = mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
|
||||
AsyncMock(return_value=real_dashboard),
|
||||
)
|
||||
|
||||
client.get("/platform-costs/dashboard")
|
||||
client.get("/platform-costs/dashboard")
|
||||
|
||||
mock_fn.assert_awaited_once() # second request hit the cache
|
||||
@@ -9,11 +9,14 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.rate_limit import (
|
||||
SubscriptionTier,
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
get_user_tier,
|
||||
reset_user_usage,
|
||||
set_user_tier,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_email_by_id
|
||||
from backend.data.user import get_user_by_email, get_user_email_by_id, search_users
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,6 +36,17 @@ class UserRateLimitResponse(BaseModel):
|
||||
weekly_token_limit: int
|
||||
daily_tokens_used: int
|
||||
weekly_tokens_used: int
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
class UserTierResponse(BaseModel):
|
||||
user_id: str
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
class SetUserTierRequest(BaseModel):
|
||||
user_id: str
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
async def _resolve_user_id(
|
||||
@@ -86,10 +100,10 @@ async def get_user_rate_limit(
|
||||
|
||||
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
resolved_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
|
||||
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=resolved_id,
|
||||
@@ -98,6 +112,7 @@ async def get_user_rate_limit(
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@@ -125,10 +140,10 @@ async def reset_user_rate_limit(
|
||||
logger.exception("Failed to reset user usage")
|
||||
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
@@ -143,4 +158,102 @@ async def reset_user_rate_limit(
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit/tier",
|
||||
response_model=UserTierResponse,
|
||||
summary="Get User Rate Limit Tier",
|
||||
)
|
||||
async def get_user_rate_limit_tier(
|
||||
user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserTierResponse:
|
||||
"""Get a user's current rate-limit tier. Admin-only.
|
||||
|
||||
Returns 404 if the user does not exist in the database.
|
||||
"""
|
||||
logger.info("Admin %s checking tier for user %s", admin_user_id, user_id)
|
||||
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
if resolved_email is None:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
|
||||
|
||||
tier = await get_user_tier(user_id)
|
||||
return UserTierResponse(user_id=user_id, tier=tier)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rate_limit/tier",
|
||||
response_model=UserTierResponse,
|
||||
summary="Set User Rate Limit Tier",
|
||||
)
|
||||
async def set_user_rate_limit_tier(
|
||||
request: SetUserTierRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserTierResponse:
|
||||
"""Set a user's rate-limit tier. Admin-only.
|
||||
|
||||
Returns 404 if the user does not exist in the database.
|
||||
"""
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(request.user_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to resolve email for user %s",
|
||||
request.user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
resolved_email = None
|
||||
|
||||
if resolved_email is None:
|
||||
raise HTTPException(status_code=404, detail=f"User {request.user_id} not found")
|
||||
|
||||
old_tier = await get_user_tier(request.user_id)
|
||||
logger.info(
|
||||
"Admin %s changing tier for user %s (%s): %s -> %s",
|
||||
admin_user_id,
|
||||
request.user_id,
|
||||
resolved_email,
|
||||
old_tier.value,
|
||||
request.tier.value,
|
||||
)
|
||||
try:
|
||||
await set_user_tier(request.user_id, request.tier)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to set user tier")
|
||||
raise HTTPException(status_code=500, detail="Failed to set tier") from e
|
||||
|
||||
return UserTierResponse(user_id=request.user_id, tier=request.tier)
|
||||
|
||||
|
||||
class UserSearchResult(BaseModel):
|
||||
user_id: str
|
||||
user_email: Optional[str] = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit/search_users",
|
||||
response_model=list[UserSearchResult],
|
||||
summary="Search Users by Name or Email",
|
||||
)
|
||||
async def admin_search_users(
|
||||
query: str,
|
||||
limit: int = 20,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> list[UserSearchResult]:
|
||||
"""Search users by partial email or name. Admin-only.
|
||||
|
||||
Queries the User table directly — returns results even for users
|
||||
without credit transaction history.
|
||||
"""
|
||||
if len(query.strip()) < 3:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Search query must be at least 3 characters.",
|
||||
)
|
||||
logger.info("Admin %s searching users with query=%r", admin_user_id, query)
|
||||
results = await search_users(query, limit=max(1, min(limit, 50)))
|
||||
return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
|
||||
|
||||
from .rate_limit_admin_routes import router as rate_limit_admin_router
|
||||
|
||||
@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
@@ -89,6 +89,7 @@ def test_get_rate_limit(
|
||||
assert data["weekly_token_limit"] == 12_500_000
|
||||
assert data["daily_tokens_used"] == 500_000
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
@@ -162,6 +163,7 @@ def test_reset_user_usage_daily_only(
|
||||
assert data["daily_tokens_used"] == 0
|
||||
# Weekly is untouched
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
|
||||
|
||||
@@ -192,6 +194,7 @@ def test_reset_user_usage_daily_and_weekly(
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
assert data["weekly_tokens_used"] == 0
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
|
||||
|
||||
@@ -228,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
@@ -261,3 +264,303 @@ def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
json={"user_id": "test"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tier management endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_user_tier(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting a user's rate-limit tier."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.PRO,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "PRO"
|
||||
|
||||
|
||||
def test_get_user_tier_user_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that getting tier for a non-existent user returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test setting a user's rate-limit tier (upgrade)."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.FREE,
|
||||
)
|
||||
mock_set = mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "ENTERPRISE"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "ENTERPRISE"
|
||||
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE)
|
||||
|
||||
|
||||
def test_set_user_tier_downgrade(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test downgrading a user's tier from PRO to FREE."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.PRO,
|
||||
)
|
||||
mock_set = mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "FREE"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "FREE"
|
||||
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
|
||||
|
||||
|
||||
def test_set_user_tier_invalid_tier(
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting an invalid tier returns 422."""
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "invalid"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_set_user_tier_invalid_tier_uppercase(
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422.
|
||||
|
||||
Regression: ensures Pydantic enum validation rejects values that are not
|
||||
members of SubscriptionTier, even when they look like valid enum names.
|
||||
"""
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "INVALID"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
body = response.json()
|
||||
assert "detail" in body
|
||||
|
||||
|
||||
def test_set_user_tier_email_lookup_failure_returns_404(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that email lookup failure returns 404 (user unverifiable)."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection failed"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier_user_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting tier for a non-existent user returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier_db_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that DB failure on set tier returns 500."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.FREE,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection refused"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that tier admin endpoints require admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": "test"})
|
||||
assert response.status_code == 403
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": "test", "tier": "PRO"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ─── search_users endpoint ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_search_users_returns_matching_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Partial search should return all matching users from the User table."""
|
||||
mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[
|
||||
("user-1", "zamil.majdy@gmail.com"),
|
||||
("user-2", "zamil.majdy@agpt.co"),
|
||||
],
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"})
|
||||
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
assert len(results) == 2
|
||||
assert results[0]["user_email"] == "zamil.majdy@gmail.com"
|
||||
assert results[1]["user_email"] == "zamil.majdy@agpt.co"
|
||||
|
||||
|
||||
def test_search_users_empty_results(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Search with no matches returns empty list."""
|
||||
mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/rate_limit/search_users", params={"query": "nonexistent"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
def test_search_users_short_query_rejected(
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Query shorter than 3 characters should return 400."""
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "ab"})
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_search_users_negative_limit_clamped(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Negative limit should be clamped to 1, not passed through."""
|
||||
mock_search = mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/rate_limit/search_users", params={"query": "test", "limit": -1}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_search.assert_awaited_once_with("test", limit=1)
|
||||
|
||||
|
||||
def test_search_users_requires_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that the search_users endpoint requires admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "test"})
|
||||
assert response.status_code == 403
|
||||
|
||||
@@ -15,7 +15,8 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.config import ChatConfig, CopilotMode
|
||||
from backend.copilot.db import get_chat_messages_paginated
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
@@ -111,6 +112,11 @@ class StreamChatRequest(BaseModel):
|
||||
file_ids: list[str] | None = Field(
|
||||
default=None, max_length=20
|
||||
) # Workspace file IDs attached to this message
|
||||
mode: CopilotMode | None = Field(
|
||||
default=None,
|
||||
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
|
||||
"If None, uses the server default (extended_thinking).",
|
||||
)
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
@@ -150,6 +156,8 @@ class SessionDetailResponse(BaseModel):
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
has_more_messages: bool = False
|
||||
oldest_sequence: int | None = None
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
metadata: ChatSessionMetadata = ChatSessionMetadata()
|
||||
@@ -389,60 +397,78 @@ async def update_session_title_route(
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
before_sequence: int | None = Query(default=None, ge=0),
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns active_stream info for reconnection.
|
||||
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
|
||||
When no pagination params are provided, returns the most recent messages.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of messages to return (1-200, default 50).
|
||||
before_sequence: Return messages with sequence < this value (cursor).
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
||||
|
||||
SessionDetailResponse: Details for the requested session, including
|
||||
active_stream info and pagination metadata.
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
page = await get_chat_messages_paginated(
|
||||
session_id, limit, before_sequence, user_id=user_id
|
||||
)
|
||||
if page is None:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
messages = [message.model_dump() for message in page.messages]
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
# Only check active stream on initial load (not on "load more" requests)
|
||||
active_stream_info = None
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_session:
|
||||
# Keep the assistant message (including tool_calls) so the frontend can
|
||||
# render the correct tool UI (e.g. CreateAgent with mini game).
|
||||
# convertChatSessionToUiMessages handles isComplete=false by setting
|
||||
# tool parts without output to state "input-available".
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
if before_sequence is None:
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_session:
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Skip session metadata on "load more" — frontend only needs messages
|
||||
if before_sequence is not None:
|
||||
return SessionDetailResponse(
|
||||
id=page.session.session_id,
|
||||
created_at=page.session.started_at.isoformat(),
|
||||
updated_at=page.session.updated_at.isoformat(),
|
||||
user_id=page.session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=None,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
total_prompt_tokens=0,
|
||||
total_completion_tokens=0,
|
||||
)
|
||||
|
||||
# Sum token usage from session
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
total_prompt = sum(u.prompt_tokens for u in page.session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in page.session.usage)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
id=page.session.session_id,
|
||||
created_at=page.session.started_at.isoformat(),
|
||||
updated_at=page.session.updated_at.isoformat(),
|
||||
user_id=page.session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
metadata=session.metadata,
|
||||
metadata=page.session.metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -456,8 +482,9 @@ async def get_copilot_usage(
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
Global defaults sourced from LaunchDarkly (falling back to config).
|
||||
Includes the user's rate-limit tier.
|
||||
"""
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
return await get_usage_status(
|
||||
@@ -465,6 +492,7 @@ async def get_copilot_usage(
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@@ -516,7 +544,7 @@ async def reset_copilot_usage(
|
||||
detail="Rate limit reset is not available (credit system is disabled).",
|
||||
)
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
|
||||
@@ -550,10 +578,13 @@ async def reset_copilot_usage(
|
||||
|
||||
try:
|
||||
# Verify the user is actually at or over their daily limit.
|
||||
# (rate_limit_reset_cost intentionally omitted — this object is only
|
||||
# used for limit checks, not returned to the client.)
|
||||
usage_status = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
tier=tier,
|
||||
)
|
||||
if daily_limit > 0 and usage_status.daily.used < daily_limit:
|
||||
raise HTTPException(
|
||||
@@ -629,6 +660,7 @@ async def reset_copilot_usage(
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
return RateLimitResetResponse(
|
||||
@@ -739,7 +771,7 @@ async def stream_chat_post(
|
||||
# Global defaults sourced from LaunchDarkly, falling back to config.
|
||||
if user_id:
|
||||
try:
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, _ = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
await check_rate_limit(
|
||||
@@ -834,6 +866,7 @@ async def stream_chat_post(
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
mode=request.mode,
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
|
||||
@@ -9,6 +9,7 @@ import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.chat import routes as chat_routes
|
||||
from backend.copilot.rate_limit import SubscriptionTier
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(chat_routes.router)
|
||||
@@ -331,14 +332,28 @@ def _mock_usage(
|
||||
*,
|
||||
daily_used: int = 500,
|
||||
weekly_used: int = 2000,
|
||||
daily_limit: int = 10000,
|
||||
weekly_limit: int = 50000,
|
||||
tier: "SubscriptionTier" = SubscriptionTier.FREE,
|
||||
) -> AsyncMock:
|
||||
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||
"""Mock get_usage_status and get_global_rate_limits for usage endpoint tests.
|
||||
|
||||
Mocks both ``get_global_rate_limits`` (returns the given limits + tier) and
|
||||
``get_usage_status`` so that tests exercise the endpoint without hitting
|
||||
LaunchDarkly or Prisma.
|
||||
"""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(daily_limit, weekly_limit, tier),
|
||||
)
|
||||
|
||||
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||
daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at),
|
||||
)
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
@@ -369,6 +384,7 @@ def test_usage_returns_daily_and_weekly(
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
|
||||
tier=SubscriptionTier.FREE,
|
||||
)
|
||||
|
||||
|
||||
@@ -376,11 +392,9 @@ def test_usage_uses_config_limits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||
mock_get = _mock_usage(mocker)
|
||||
"""The endpoint forwards resolved limits from get_global_rate_limits to get_usage_status."""
|
||||
mock_get = _mock_usage(mocker, daily_limit=99999, weekly_limit=77777)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
|
||||
|
||||
response = client.get("/usage")
|
||||
@@ -391,6 +405,7 @@ def test_usage_uses_config_limits(
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
rate_limit_reset_cost=500,
|
||||
tier=SubscriptionTier.FREE,
|
||||
)
|
||||
|
||||
|
||||
@@ -526,3 +541,41 @@ def test_create_session_rejects_nested_metadata(
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestStreamChatRequestModeValidation:
|
||||
"""Pydantic-level validation of the ``mode`` field on StreamChatRequest."""
|
||||
|
||||
def test_rejects_invalid_mode_value(self) -> None:
|
||||
"""Any string outside the Literal set must raise ValidationError."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
StreamChatRequest(message="hi", mode="turbo") # type: ignore[arg-type]
|
||||
|
||||
def test_accepts_fast_mode(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode="fast")
|
||||
assert req.mode == "fast"
|
||||
|
||||
def test_accepts_extended_thinking_mode(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode="extended_thinking")
|
||||
assert req.mode == "extended_thinking"
|
||||
|
||||
def test_accepts_none_mode(self) -> None:
|
||||
"""``mode=None`` is valid (server decides via feature flags)."""
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode=None)
|
||||
assert req.mode is None
|
||||
|
||||
def test_mode_defaults_to_none_when_omitted(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi")
|
||||
assert req.mode is None
|
||||
|
||||
@@ -481,6 +481,11 @@ async def create_library_agent(
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
},
|
||||
},
|
||||
include=library_agent_include(
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
|
||||
from backend.api.features.v1 import v1_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(v1_router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_onboarding_profile_success(mocker):
|
||||
mock_extract = mocker.patch(
|
||||
"backend.api.features.v1.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
mock_upsert = mocker.patch(
|
||||
"backend.api.features.v1.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
from backend.data.understanding import BusinessUnderstandingInput
|
||||
|
||||
mock_extract.return_value = BusinessUnderstandingInput.model_construct(
|
||||
user_name="John",
|
||||
user_role="Founder/CEO",
|
||||
pain_points=["Finding leads"],
|
||||
suggested_prompts={"Learn": ["How do I automate lead gen?"]},
|
||||
)
|
||||
mock_upsert.return_value = AsyncMock()
|
||||
|
||||
response = client.post(
|
||||
"/onboarding/profile",
|
||||
json={
|
||||
"user_name": "John",
|
||||
"user_role": "Founder/CEO",
|
||||
"pain_points": ["Finding leads", "Email & outreach"],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_extract.assert_awaited_once()
|
||||
mock_upsert.assert_awaited_once()
|
||||
|
||||
|
||||
def test_onboarding_profile_missing_fields():
|
||||
response = client.post(
|
||||
"/onboarding/profile",
|
||||
json={"user_name": "John"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
@@ -189,6 +189,7 @@ async def test_create_store_submission(mocker):
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="Europe/Delft",
|
||||
subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
|
||||
)
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
id="agent-id",
|
||||
|
||||
@@ -63,12 +63,17 @@ from backend.data.onboarding import (
|
||||
UserOnboardingUpdate,
|
||||
complete_onboarding_step,
|
||||
complete_re_run_agent,
|
||||
format_onboarding_for_extraction,
|
||||
get_recommended_agents,
|
||||
get_user_onboarding,
|
||||
onboarding_enabled,
|
||||
reset_user_onboarding,
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.tally import extract_business_understanding
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
@@ -282,35 +287,33 @@ async def get_onboarding_agents(
|
||||
return await get_recommended_agents(user_id)
|
||||
|
||||
|
||||
class OnboardingStatusResponse(pydantic.BaseModel):
|
||||
"""Response for onboarding status check."""
|
||||
class OnboardingProfileRequest(pydantic.BaseModel):
|
||||
"""Request body for onboarding profile submission."""
|
||||
|
||||
is_onboarding_enabled: bool
|
||||
is_chat_enabled: bool
|
||||
user_name: str = pydantic.Field(min_length=1, max_length=100)
|
||||
user_role: str = pydantic.Field(min_length=1, max_length=100)
|
||||
pain_points: list[str] = pydantic.Field(default_factory=list, max_length=20)
|
||||
|
||||
|
||||
class OnboardingStatusResponse(pydantic.BaseModel):
|
||||
"""Response for onboarding completion check."""
|
||||
|
||||
is_completed: bool
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/onboarding/enabled",
|
||||
summary="Is onboarding enabled",
|
||||
"/onboarding/completed",
|
||||
summary="Check if onboarding is completed",
|
||||
tags=["onboarding", "public"],
|
||||
response_model=OnboardingStatusResponse,
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def is_onboarding_enabled(
|
||||
async def is_onboarding_completed(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> OnboardingStatusResponse:
|
||||
# Check if chat is enabled for user
|
||||
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
|
||||
|
||||
# If chat is enabled, skip legacy onboarding
|
||||
if is_chat_enabled:
|
||||
return OnboardingStatusResponse(
|
||||
is_onboarding_enabled=False,
|
||||
is_chat_enabled=True,
|
||||
)
|
||||
|
||||
user_onboarding = await get_user_onboarding(user_id)
|
||||
return OnboardingStatusResponse(
|
||||
is_onboarding_enabled=await onboarding_enabled(),
|
||||
is_chat_enabled=False,
|
||||
is_completed=OnboardingStep.VISIT_COPILOT in user_onboarding.completedSteps,
|
||||
)
|
||||
|
||||
|
||||
@@ -325,6 +328,38 @@ async def reset_onboarding(user_id: Annotated[str, Security(get_user_id)]):
|
||||
return await reset_user_onboarding(user_id)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/onboarding/profile",
|
||||
summary="Submit onboarding profile",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def submit_onboarding_profile(
|
||||
data: OnboardingProfileRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
):
|
||||
formatted = format_onboarding_for_extraction(
|
||||
user_name=data.user_name,
|
||||
user_role=data.user_role,
|
||||
pain_points=data.pain_points,
|
||||
)
|
||||
|
||||
try:
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
except Exception:
|
||||
understanding_input = BusinessUnderstandingInput.model_construct()
|
||||
|
||||
# Ensure the direct fields are set even if LLM missed them
|
||||
understanding_input.user_name = data.user_name
|
||||
understanding_input.user_role = data.user_role
|
||||
if not understanding_input.pain_points:
|
||||
understanding_input.pain_points = data.pain_points
|
||||
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Blocks ###########################
|
||||
########################################################
|
||||
|
||||
@@ -12,7 +12,7 @@ import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
@@ -131,9 +131,26 @@ class StorageUsageResponse(BaseModel):
|
||||
file_count: int
|
||||
|
||||
|
||||
class WorkspaceFileItem(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
created_at: str
|
||||
|
||||
|
||||
class ListFilesResponse(BaseModel):
|
||||
files: list[WorkspaceFileItem]
|
||||
offset: int = 0
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files/{file_id}/download",
|
||||
summary="Download file by ID",
|
||||
operation_id="getWorkspaceDownloadFileById",
|
||||
)
|
||||
async def download_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -158,6 +175,7 @@ async def download_file(
|
||||
@router.delete(
|
||||
"/files/{file_id}",
|
||||
summary="Delete a workspace file",
|
||||
operation_id="deleteWorkspaceFile",
|
||||
)
|
||||
async def delete_workspace_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -183,6 +201,7 @@ async def delete_workspace_file(
|
||||
@router.post(
|
||||
"/files/upload",
|
||||
summary="Upload file to workspace",
|
||||
operation_id="uploadWorkspaceFile",
|
||||
)
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -196,6 +215,9 @@ async def upload_file(
|
||||
Files are stored in session-scoped paths when session_id is provided,
|
||||
so the agent's session-scoped tools can discover them automatically.
|
||||
"""
|
||||
# Empty-string session_id drops session scoping; normalize to None.
|
||||
session_id = session_id or None
|
||||
|
||||
config = Config()
|
||||
|
||||
# Sanitize filename — strip any directory components
|
||||
@@ -250,16 +272,27 @@ async def upload_file(
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
try:
|
||||
workspace_file = await manager.write_file(
|
||||
content, filename, overwrite=overwrite
|
||||
content, filename, overwrite=overwrite, metadata={"origin": "user-upload"}
|
||||
)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
# write_file raises ValueError for both path-conflict and size-limit
|
||||
# cases; map each to its correct HTTP status.
|
||||
message = str(e)
|
||||
if message.startswith("File too large"):
|
||||
raise fastapi.HTTPException(status_code=413, detail=message) from e
|
||||
raise fastapi.HTTPException(status_code=409, detail=message) from e
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
new_total = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and new_total > storage_limit_bytes:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
try:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to soft-delete over-quota file {workspace_file.id} "
|
||||
f"in workspace {workspace.id}: {e}"
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
@@ -281,6 +314,7 @@ async def upload_file(
|
||||
@router.get(
|
||||
"/storage/usage",
|
||||
summary="Get workspace storage usage",
|
||||
operation_id="getWorkspaceStorageUsage",
|
||||
)
|
||||
async def get_storage_usage(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -301,3 +335,57 @@ async def get_storage_usage(
|
||||
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files",
|
||||
summary="List workspace files",
|
||||
operation_id="listWorkspaceFiles",
|
||||
)
|
||||
async def list_workspace_files(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
session_id: str | None = Query(default=None),
|
||||
limit: int = Query(default=200, ge=1, le=1000),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> ListFilesResponse:
|
||||
"""
|
||||
List files in the user's workspace.
|
||||
|
||||
When session_id is provided, only files for that session are returned.
|
||||
Otherwise, all files across sessions are listed. Results are paginated
|
||||
via `limit`/`offset`; `has_more` indicates whether additional pages exist.
|
||||
"""
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Treat empty-string session_id the same as omitted — an empty value
|
||||
# would otherwise silently list files across every session instead of
|
||||
# scoping to one.
|
||||
session_id = session_id or None
|
||||
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
include_all = session_id is None
|
||||
# Fetch one extra to compute has_more without a separate count query.
|
||||
files = await manager.list_files(
|
||||
limit=limit + 1,
|
||||
offset=offset,
|
||||
include_all_sessions=include_all,
|
||||
)
|
||||
has_more = len(files) > limit
|
||||
page = files[:limit]
|
||||
|
||||
return ListFilesResponse(
|
||||
files=[
|
||||
WorkspaceFileItem(
|
||||
id=f.id,
|
||||
name=f.name,
|
||||
path=f.path,
|
||||
mime_type=f.mime_type,
|
||||
size_bytes=f.size_bytes,
|
||||
metadata=f.metadata or {},
|
||||
created_at=f.created_at.isoformat(),
|
||||
)
|
||||
for f in page
|
||||
],
|
||||
offset=offset,
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
@@ -1,48 +1,28 @@
|
||||
"""Tests for workspace file upload and download routes."""
|
||||
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.workspace import routes as workspace_routes
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
from backend.api.features.workspace.routes import router
|
||||
from backend.data.workspace import Workspace, WorkspaceFile
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(workspace_routes.router)
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def _value_error_handler(
|
||||
request: fastapi.Request, exc: ValueError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
|
||||
"""Mirror the production ValueError → 400 mapping from the REST app."""
|
||||
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
|
||||
|
||||
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-1",
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
name="hello.txt",
|
||||
path="/session/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
@@ -53,25 +33,201 @@ def setup_app_auth(mock_jwt_user):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _make_workspace(user_id: str = "test-user-id") -> Workspace:
|
||||
return Workspace(
|
||||
id="ws-001",
|
||||
user_id=user_id,
|
||||
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _make_file(**overrides) -> WorkspaceFile:
|
||||
defaults = {
|
||||
"id": "file-001",
|
||||
"workspace_id": "ws-001",
|
||||
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"name": "test.txt",
|
||||
"path": "/test.txt",
|
||||
"storage_path": "local://test.txt",
|
||||
"mime_type": "text/plain",
|
||||
"size_bytes": 100,
|
||||
"checksum": None,
|
||||
"is_deleted": False,
|
||||
"deleted_at": None,
|
||||
"metadata": {},
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return WorkspaceFile(**defaults)
|
||||
|
||||
|
||||
def _make_file_mock(**overrides) -> MagicMock:
|
||||
"""Create a mock WorkspaceFile to simulate DB records with null fields."""
|
||||
defaults = {
|
||||
"id": "file-001",
|
||||
"name": "test.txt",
|
||||
"path": "/test.txt",
|
||||
"mime_type": "text/plain",
|
||||
"size_bytes": 100,
|
||||
"metadata": {},
|
||||
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
mock = MagicMock(spec=WorkspaceFile)
|
||||
for k, v in defaults.items():
|
||||
setattr(mock, k, v)
|
||||
return mock
|
||||
|
||||
|
||||
# -- list_workspace_files tests --
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_returns_all_when_no_session(mock_manager_cls, mock_get_workspace):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
files = [
|
||||
_make_file(id="f1", name="a.txt", metadata={"origin": "user-upload"}),
|
||||
_make_file(id="f2", name="b.csv", metadata={"origin": "agent-created"}),
|
||||
]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert len(data["files"]) == 2
|
||||
assert data["has_more"] is False
|
||||
assert data["offset"] == 0
|
||||
assert data["files"][0]["id"] == "f1"
|
||||
assert data["files"][0]["metadata"] == {"origin": "user-upload"}
|
||||
assert data["files"][1]["id"] == "f2"
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=201, offset=0, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_scopes_to_session_when_provided(
|
||||
mock_manager_cls, mock_get_workspace, test_user_id
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace(user_id=test_user_id)
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = []
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?session_id=sess-123")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert data["files"] == []
|
||||
assert data["has_more"] is False
|
||||
mock_manager_cls.assert_called_once_with(test_user_id, "ws-001", "sess-123")
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=201, offset=0, include_all_sessions=False
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_null_metadata_coerced_to_empty_dict(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""Route uses `f.metadata or {}` for pre-existing files with null metadata."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = [_make_file_mock(metadata=None)]
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["files"][0]["metadata"] == {}
|
||||
|
||||
|
||||
# -- upload_file metadata tests --
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
|
||||
@patch("backend.api.features.workspace.routes.scan_content_safe")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_upload_passes_user_upload_origin_metadata(
|
||||
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_total_size.return_value = 100
|
||||
written = _make_file(id="new-file", name="doc.pdf")
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.write_file.return_value = written
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.post(
|
||||
"/files/upload",
|
||||
files={"file": ("doc.pdf", b"fake-pdf-content", "application/pdf")},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
mock_instance.write_file.assert_called_once()
|
||||
call_kwargs = mock_instance.write_file.call_args
|
||||
assert call_kwargs.kwargs.get("metadata") == {"origin": "user-upload"}
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
|
||||
@patch("backend.api.features.workspace.routes.scan_content_safe")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_upload_returns_409_on_file_conflict(
|
||||
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_total_size.return_value = 100
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.write_file.side_effect = ValueError("File already exists at path")
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.post(
|
||||
"/files/upload",
|
||||
files={"file": ("dup.txt", b"content", "text/plain")},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
# -- Restored upload/download/delete security + invariant tests --
|
||||
|
||||
|
||||
def _upload(
|
||||
filename: str = "hello.txt",
|
||||
content: bytes = b"Hello, world!",
|
||||
content_type: str = "text/plain",
|
||||
):
|
||||
"""Helper to POST a file upload."""
|
||||
return client.post(
|
||||
"/files/upload?session_id=sess-1",
|
||||
files={"file": (filename, io.BytesIO(content), content_type)},
|
||||
)
|
||||
|
||||
|
||||
# ---- Happy path ----
|
||||
_MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-001",
|
||||
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
name="hello.txt",
|
||||
path="/sessions/sess-1/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_happy_path(mocker):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -82,7 +238,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -96,10 +252,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
assert data["size_bytes"] == 13
|
||||
|
||||
|
||||
# ---- Per-file size limit ----
|
||||
|
||||
|
||||
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_exceeds_max_file_size(mocker):
|
||||
"""Files larger than max_file_size_mb should be rejected with 413."""
|
||||
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
|
||||
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
|
||||
@@ -109,15 +262,11 @@ def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 413
|
||||
|
||||
|
||||
# ---- Storage quota exceeded ----
|
||||
|
||||
|
||||
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_storage_quota_exceeded(mocker):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
# Current usage already at limit
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=500 * 1024 * 1024,
|
||||
@@ -128,27 +277,22 @@ def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
assert "Storage limit exceeded" in response.text
|
||||
|
||||
|
||||
# ---- Post-write quota race (B2) ----
|
||||
|
||||
|
||||
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
"""If a concurrent upload tips the total over the limit after write,
|
||||
the file should be soft-deleted and 413 returned."""
|
||||
def test_upload_post_write_quota_race(mocker):
|
||||
"""Concurrent upload tipping over limit after write should soft-delete + 413."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
# Pre-write check passes (under limit), but post-write check fails
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
|
||||
side_effect=[0, 600 * 1024 * 1024],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -160,17 +304,14 @@ def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-001")
|
||||
|
||||
|
||||
# ---- Any extension accepted (no allowlist) ----
|
||||
|
||||
|
||||
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_any_extension(mocker):
|
||||
"""Any file extension should be accepted — ClamAV is the security layer."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -181,7 +322,7 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -191,16 +332,13 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- Virus scan rejection ----
|
||||
|
||||
|
||||
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_blocked_by_virus_scan(mocker):
|
||||
"""Files flagged by ClamAV should be rejected and never written to storage."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -211,7 +349,7 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
side_effect=VirusDetectedError("Eicar-Test-Signature"),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -219,18 +357,14 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
|
||||
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
|
||||
assert response.status_code == 400
|
||||
assert "Virus detected" in response.text
|
||||
mock_manager.write_file.assert_not_called()
|
||||
|
||||
|
||||
# ---- No file extension ----
|
||||
|
||||
|
||||
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_file_without_extension(mocker):
|
||||
"""Files without an extension should be accepted and stored as-is."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -241,7 +375,7 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -257,14 +391,11 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
assert mock_manager.write_file.call_args[0][1] == "Makefile"
|
||||
|
||||
|
||||
# ---- Filename sanitization (SF5) ----
|
||||
|
||||
|
||||
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_strips_path_components(mocker):
|
||||
"""Path-traversal filenames should be reduced to their basename."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -275,28 +406,23 @@ def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
# Filename with traversal
|
||||
_upload(filename="../../etc/passwd.txt")
|
||||
|
||||
# write_file should have been called with just the basename
|
||||
mock_manager.write_file.assert_called_once()
|
||||
call_args = mock_manager.write_file.call_args
|
||||
assert call_args[0][1] == "passwd.txt"
|
||||
|
||||
|
||||
# ---- Download ----
|
||||
|
||||
|
||||
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
def test_download_file_not_found(mocker):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_file",
|
||||
@@ -307,14 +433,11 @@ def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- Delete ----
|
||||
|
||||
|
||||
def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
def test_delete_file_success(mocker):
|
||||
"""Deleting an existing file should return {"deleted": true}."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
|
||||
@@ -329,11 +452,11 @@ def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
|
||||
|
||||
|
||||
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
def test_delete_file_not_found(mocker):
|
||||
"""Deleting a non-existent file should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
|
||||
@@ -347,7 +470,7 @@ def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
assert "File not found" in response.text
|
||||
|
||||
|
||||
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
def test_delete_file_no_workspace(mocker):
|
||||
"""Deleting when user has no workspace should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
@@ -357,3 +480,123 @@ def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 404
|
||||
assert "Workspace not found" in response.text
|
||||
|
||||
|
||||
def test_upload_write_file_too_large_returns_413(mocker):
|
||||
"""write_file raises ValueError("File too large: …") → must map to 413."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
assert "File too large" in response.text
|
||||
|
||||
|
||||
def test_upload_write_file_conflict_returns_409(mocker):
|
||||
"""Non-'File too large' ValueErrors from write_file stay as 409."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=ValueError("File already exists at path: /sessions/x/a.txt")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.text
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_has_more_true_when_limit_exceeded(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""The limit+1 fetch trick must flip has_more=True and trim the page."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
# Backend was asked for limit+1=3, and returned exactly 3 items.
|
||||
files = [
|
||||
_make_file(id="f1", name="a.txt"),
|
||||
_make_file(id="f2", name="b.txt"),
|
||||
_make_file(id="f3", name="c.txt"),
|
||||
]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["has_more"] is True
|
||||
assert len(data["files"]) == 2
|
||||
assert data["files"][0]["id"] == "f1"
|
||||
assert data["files"][1]["id"] == "f2"
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=3, offset=0, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_has_more_false_when_exactly_page_size(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""Exactly `limit` rows means we're on the last page — has_more=False."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
files = [_make_file(id="f1", name="a.txt"), _make_file(id="f2", name="b.txt")]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["has_more"] is False
|
||||
assert len(data["files"]) == 2
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = []
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?offset=50&limit=10")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["offset"] == 50
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=11, offset=50, include_all_sessions=True
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.platform_cost_routes
|
||||
import backend.api.features.admin.rate_limit_admin_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
@@ -329,6 +330,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/copilot",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.platform_cost_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/admin",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
|
||||
@@ -698,13 +698,30 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
if should_pause:
|
||||
return
|
||||
|
||||
# Validate the input data (original or reviewer-modified) once
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
# Validate the input data (original or reviewer-modified) once.
|
||||
# In dry-run mode, credential fields may contain sentinel None values
|
||||
# that would fail JSON schema required checks. We still validate the
|
||||
# non-credential fields so blocks that execute for real during dry-run
|
||||
# (e.g. AgentExecutorBlock) get proper input validation.
|
||||
is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
|
||||
if is_dry_run:
|
||||
cred_field_names = set(self.input_schema.get_credentials_fields().keys())
|
||||
non_cred_data = {
|
||||
k: v for k, v in input_data.items() if k not in cred_field_names
|
||||
}
|
||||
if error := self.input_schema.validate_data(non_cred_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
else:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Use the validated input data
|
||||
async for output_name, output_data in self.run(
|
||||
|
||||
@@ -49,11 +49,17 @@ class AgentExecutorBlock(Block):
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
required_fields = cls.get_input_schema(data).get("required", [])
|
||||
return set(required_fields) - set(data)
|
||||
# Check against the nested `inputs` dict, not the top-level node
|
||||
# data — required fields like "topic" live inside data["inputs"],
|
||||
# not at data["topic"].
|
||||
provided = data.get("inputs", {})
|
||||
return set(required_fields) - set(provided)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
return validate_with_jsonschema(
|
||||
cls.get_input_schema(data), data.get("inputs", {})
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Use BlockSchema to avoid automatic error field that could clash with graph outputs
|
||||
@@ -88,6 +94,7 @@ class AgentExecutorBlock(Block):
|
||||
execution_context=execution_context.model_copy(
|
||||
update={"parent_execution_id": graph_exec_id},
|
||||
),
|
||||
dry_run=execution_context.dry_run,
|
||||
)
|
||||
|
||||
logger = execution_utils.LogMetadata(
|
||||
@@ -149,14 +156,19 @@ class AgentExecutorBlock(Block):
|
||||
ExecutionStatus.TERMINATED,
|
||||
ExecutionStatus.FAILED,
|
||||
]:
|
||||
logger.debug(
|
||||
f"Execution {log_id} received event {event.event_type} with status {event.status}"
|
||||
logger.info(
|
||||
f"Execution {log_id} skipping event {event.event_type} status={event.status} "
|
||||
f"node={getattr(event, 'node_exec_id', '?')}"
|
||||
)
|
||||
continue
|
||||
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
|
||||
# we can stop listening for further events.
|
||||
logger.info(
|
||||
f"Execution {log_id} graph completed with status {event.status}, "
|
||||
f"yielded {len(yielded_node_exec_ids)} outputs"
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
extra_cost=event.stats.cost if event.stats else 0,
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.apollo.models import (
|
||||
PrimaryPhone,
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
@@ -218,6 +218,11 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
) -> BlockOutput:
|
||||
query = SearchOrganizationsRequest(**input_data.model_dump())
|
||||
organizations = await self.search_organizations(query, credentials)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(organizations)), provider_cost_type="items"
|
||||
)
|
||||
)
|
||||
for organization in organizations:
|
||||
yield "organization", organization
|
||||
yield "organizations", organizations
|
||||
|
||||
@@ -21,7 +21,7 @@ from backend.blocks.apollo.models import (
|
||||
SearchPeopleRequest,
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
@@ -366,4 +366,9 @@ class SearchPeopleBlock(Block):
|
||||
*(enrich_or_fallback(person) for person in people)
|
||||
)
|
||||
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(people)), provider_cost_type="items"
|
||||
)
|
||||
)
|
||||
yield "people", people
|
||||
|
||||
@@ -0,0 +1,712 @@
|
||||
"""Unit tests for merge_stats cost tracking in individual blocks.
|
||||
|
||||
Covers the exa code_context, exa contents, and apollo organization blocks
|
||||
to verify provider cost is correctly extracted and reported.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TEST_EXA_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="exa",
|
||||
api_key=SecretStr("mock-exa-api-key"),
|
||||
title="Mock Exa API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_EXA_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_EXA_CREDENTIALS.provider,
|
||||
"id": TEST_EXA_CREDENTIALS.id,
|
||||
"type": TEST_EXA_CREDENTIALS.type,
|
||||
"title": TEST_EXA_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaCodeContextBlock — cost_dollars is a string like "0.005"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaCodeContextBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_float_cost(self):
|
||||
"""float(cost_dollars) parsed from API string and passed to merge_stats."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-1",
|
||||
"query": "how to use hooks",
|
||||
"response": "Here are some examples...",
|
||||
"resultsCount": 3,
|
||||
"costDollars": "0.005",
|
||||
"searchTime": 1.2,
|
||||
"outputTokens": 100,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="how to use hooks",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
results = []
|
||||
async for output in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
results.append(output)
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(0.005)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cost_dollars_does_not_raise(self):
|
||||
"""When cost_dollars cannot be parsed as float, merge_stats is not called."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-2",
|
||||
"query": "query",
|
||||
"response": "response",
|
||||
"resultsCount": 0,
|
||||
"costDollars": "N/A",
|
||||
"searchTime": 0.5,
|
||||
"outputTokens": 0,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
merge_calls: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: merge_calls.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="query",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merge_calls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cost_is_tracked(self):
|
||||
"""A zero cost_dollars string '0.0' should still be recorded."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-3",
|
||||
"query": "query",
|
||||
"response": "...",
|
||||
"resultsCount": 1,
|
||||
"costDollars": "0.0",
|
||||
"searchTime": 0.1,
|
||||
"outputTokens": 10,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="query",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaContentsBlock — response.cost_dollars.total (CostDollars model)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaContentsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_cost_dollars_total(self):
|
||||
"""provider_cost equals response.cost_dollars.total when present."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
|
||||
block = ExaContentsBlock()
|
||||
|
||||
cost_dollars = CostDollars(total=0.012)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_response.context = None
|
||||
mock_response.statuses = None
|
||||
mock_response.cost_dollars = cost_dollars
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.contents.AsyncExa",
|
||||
return_value=MagicMock(
|
||||
get_contents=AsyncMock(return_value=mock_response)
|
||||
),
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaContentsBlock.Input(
|
||||
urls=["https://example.com"],
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(0.012)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_stats_when_cost_dollars_absent(self):
|
||||
"""When response.cost_dollars is None, merge_stats is not called."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
|
||||
block = ExaContentsBlock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_response.context = None
|
||||
mock_response.statuses = None
|
||||
mock_response.cost_dollars = None
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.contents.AsyncExa",
|
||||
return_value=MagicMock(
|
||||
get_contents=AsyncMock(return_value=mock_response)
|
||||
),
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaContentsBlock.Input(
|
||||
urls=["https://example.com"],
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert accumulated == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SearchOrganizationsBlock — provider_cost = float(len(organizations))
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchOrganizationsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_org_count(self):
|
||||
"""provider_cost == number of returned organizations, type == 'items'."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.models import Organization
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
|
||||
block = SearchOrganizationsBlock()
|
||||
|
||||
fake_orgs = [Organization(id=str(i), name=f"Org{i}") for i in range(3)]
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchOrganizationsBlock,
|
||||
"search_organizations",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_orgs,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchOrganizationsBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
results = []
|
||||
async for output in block.run(
|
||||
input_data,
|
||||
credentials=APOLLO_CREDS,
|
||||
):
|
||||
results.append(output)
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(3.0)
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_org_list_tracks_zero(self):
|
||||
"""An empty organization list results in provider_cost=0.0."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
|
||||
block = SearchOrganizationsBlock()
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchOrganizationsBlock,
|
||||
"search_organizations",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchOrganizationsBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=APOLLO_CREDS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JinaEmbeddingBlock — token count from usage.total_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJinaEmbeddingBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_token_count(self):
|
||||
"""provider token count is recorded when API returns usage.total_tokens."""
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
|
||||
block = JinaEmbeddingBlock()
|
||||
|
||||
api_response = {
|
||||
"data": [{"embedding": [0.1, 0.2, 0.3]}],
|
||||
"usage": {"total_tokens": 42},
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.jina.embeddings.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = JinaEmbeddingBlock.Input(
|
||||
texts=["hello world"],
|
||||
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=JINA_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].input_token_count == 42
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_stats_when_usage_absent(self):
|
||||
"""When API response omits usage field, merge_stats is not called."""
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
|
||||
block = JinaEmbeddingBlock()
|
||||
|
||||
api_response = {
|
||||
"data": [{"embedding": [0.1, 0.2, 0.3]}],
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.jina.embeddings.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = JinaEmbeddingBlock.Input(
|
||||
texts=["hello"],
|
||||
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=JINA_CREDS):
|
||||
pass
|
||||
|
||||
assert accumulated == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UnrealTextToSpeechBlock — character count from input text length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUnrealTextToSpeechBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_character_count(self):
|
||||
"""provider_cost equals len(text) with type='characters'."""
|
||||
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
|
||||
from backend.blocks.text_to_speech_block import (
|
||||
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
|
||||
block = UnrealTextToSpeechBlock()
|
||||
test_text = "Hello, world!"
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
UnrealTextToSpeechBlock,
|
||||
"call_unreal_speech_api",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"OutputUri": "https://example.com/audio.mp3"},
|
||||
),
|
||||
patch.object(block, "merge_stats") as mock_merge,
|
||||
):
|
||||
input_data = UnrealTextToSpeechBlock.Input(
|
||||
text=test_text,
|
||||
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=TTS_CREDS):
|
||||
pass
|
||||
|
||||
mock_merge.assert_called_once()
|
||||
stats = mock_merge.call_args[0][0]
|
||||
assert stats.provider_cost == float(len(test_text))
|
||||
assert stats.provider_cost_type == "characters"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_text_gives_zero_characters(self):
|
||||
"""An empty text string results in provider_cost=0.0."""
|
||||
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
|
||||
from backend.blocks.text_to_speech_block import (
|
||||
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
|
||||
block = UnrealTextToSpeechBlock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
UnrealTextToSpeechBlock,
|
||||
"call_unreal_speech_api",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"OutputUri": "https://example.com/audio.mp3"},
|
||||
),
|
||||
patch.object(block, "merge_stats") as mock_merge,
|
||||
):
|
||||
input_data = UnrealTextToSpeechBlock.Input(
|
||||
text="",
|
||||
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=TTS_CREDS):
|
||||
pass
|
||||
|
||||
mock_merge.assert_called_once()
|
||||
stats = mock_merge.call_args[0][0]
|
||||
assert stats.provider_cost == 0.0
|
||||
assert stats.provider_cost_type == "characters"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GoogleMapsSearchBlock — item count from search_places results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGoogleMapsSearchBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_place_count(self):
|
||||
"""provider_cost equals number of returned places, type == 'items'."""
|
||||
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
|
||||
from backend.blocks.google_maps import (
|
||||
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.google_maps import GoogleMapsSearchBlock
|
||||
|
||||
block = GoogleMapsSearchBlock()
|
||||
|
||||
fake_places = [{"name": f"Place{i}", "address": f"Addr{i}"} for i in range(4)]
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
GoogleMapsSearchBlock,
|
||||
"search_places",
|
||||
return_value=fake_places,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = GoogleMapsSearchBlock.Input(
|
||||
query="coffee shops",
|
||||
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=MAPS_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 4.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_results_tracks_zero(self):
|
||||
"""Zero places returned results in provider_cost=0.0."""
|
||||
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
|
||||
from backend.blocks.google_maps import (
|
||||
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.google_maps import GoogleMapsSearchBlock
|
||||
|
||||
block = GoogleMapsSearchBlock()
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
GoogleMapsSearchBlock,
|
||||
"search_places",
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = GoogleMapsSearchBlock.Input(
|
||||
query="nothing here",
|
||||
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=MAPS_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SmartLeadAddLeadsBlock — item count from lead_list length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSmartLeadAddLeadsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_lead_count(self):
|
||||
"""provider_cost equals number of leads uploaded, type == 'items'."""
|
||||
from backend.blocks.smartlead._auth import TEST_CREDENTIALS as SL_CREDS
|
||||
from backend.blocks.smartlead._auth import (
|
||||
TEST_CREDENTIALS_INPUT as SL_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.smartlead.campaign import AddLeadToCampaignBlock
|
||||
from backend.blocks.smartlead.models import (
|
||||
AddLeadsToCampaignResponse,
|
||||
LeadInput,
|
||||
)
|
||||
|
||||
block = AddLeadToCampaignBlock()
|
||||
|
||||
fake_leads = [
|
||||
LeadInput(first_name="Alice", last_name="A", email="alice@example.com"),
|
||||
LeadInput(first_name="Bob", last_name="B", email="bob@example.com"),
|
||||
]
|
||||
fake_response = AddLeadsToCampaignResponse(
|
||||
ok=True,
|
||||
upload_count=2,
|
||||
total_leads=2,
|
||||
block_count=0,
|
||||
duplicate_count=0,
|
||||
invalid_email_count=0,
|
||||
invalid_emails=[],
|
||||
already_added_to_campaign=0,
|
||||
unsubscribed_leads=[],
|
||||
is_lead_limit_exhausted=False,
|
||||
lead_import_stopped_count=0,
|
||||
bounce_count=0,
|
||||
)
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
AddLeadToCampaignBlock,
|
||||
"add_leads_to_campaign",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_response,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = AddLeadToCampaignBlock.Input(
|
||||
campaign_id=123,
|
||||
lead_list=fake_leads,
|
||||
credentials=SL_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=SL_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 2.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SearchPeopleBlock — item count from people list length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchPeopleBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_people_count(self):
|
||||
"""provider_cost equals number of returned people, type == 'items'."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.models import Contact
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
|
||||
block = SearchPeopleBlock()
|
||||
fake_people = [Contact(id=str(i), first_name=f"Person{i}") for i in range(5)]
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchPeopleBlock,
|
||||
"search_people",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_people,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchPeopleBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(5.0)
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_people_list_tracks_zero(self):
|
||||
"""An empty people list results in provider_cost=0.0."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
|
||||
block = SearchPeopleBlock()
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchPeopleBlock,
|
||||
"search_people",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchPeopleBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
@@ -9,6 +9,7 @@ from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -116,3 +117,10 @@ class ExaCodeContextBlock(Block):
|
||||
yield "cost_dollars", context.cost_dollars
|
||||
yield "search_time", context.search_time
|
||||
yield "output_tokens", context.output_tokens
|
||||
|
||||
# Parse cost_dollars (API returns as string, e.g. "0.005")
|
||||
try:
|
||||
cost_usd = float(context.cost_dollars)
|
||||
self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
from exa_py import AsyncExa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -223,3 +224,6 @@ class ExaContentsBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -0,0 +1,575 @@
|
||||
"""Tests for cost tracking in Exa blocks.
|
||||
|
||||
Covers the cost_dollars → provider_cost → merge_stats path for both
|
||||
ExaContentsBlock and ExaCodeContextBlock.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
|
||||
class TestExaCodeContextCostTracking:
|
||||
"""ExaCodeContextBlock parses cost_dollars (string) and calls merge_stats."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_cost_string_is_parsed_and_merged(self):
|
||||
"""A numeric cost string like '0.005' is merged as provider_cost."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-1",
|
||||
"query": "test query",
|
||||
"response": "some code",
|
||||
"resultsCount": 3,
|
||||
"costDollars": "0.005",
|
||||
"searchTime": 1.2,
|
||||
"outputTokens": 100,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
outputs = []
|
||||
async for key, value in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
outputs.append((key, value))
|
||||
|
||||
assert any(k == "cost_dollars" for k, _ in outputs)
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.005)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cost_string_does_not_raise(self):
|
||||
"""A non-numeric cost_dollars value is swallowed silently."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-2",
|
||||
"query": "test",
|
||||
"response": "code",
|
||||
"resultsCount": 0,
|
||||
"costDollars": "N/A",
|
||||
"searchTime": 0.5,
|
||||
"outputTokens": 0,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
outputs = []
|
||||
async for key, value in block.run(
|
||||
block.Input(query="test", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
outputs.append((key, value))
|
||||
|
||||
# No merge_stats call because float() raised ValueError
|
||||
assert len(merged) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cost_string_is_merged(self):
|
||||
"""'0.0' is a valid cost — should still be tracked."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-3",
|
||||
"query": "free query",
|
||||
"response": "result",
|
||||
"resultsCount": 1,
|
||||
"costDollars": "0.0",
|
||||
"searchTime": 0.1,
|
||||
"outputTokens": 10,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="free query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.0)
|
||||
|
||||
|
||||
class TestExaContentsCostTracking:
|
||||
"""ExaContentsBlock merges cost_dollars.total as provider_cost."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_dollars_total_is_merged(self):
|
||||
"""When the SDK response includes cost_dollars, its total is merged."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
|
||||
block = ExaContentsBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.statuses = None
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.012)
|
||||
|
||||
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.012)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_dollars_skips_merge(self):
|
||||
"""When cost_dollars is absent, merge_stats is not called."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
|
||||
block = ExaContentsBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.statuses = None
|
||||
mock_sdk_response.cost_dollars = None
|
||||
|
||||
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cost_dollars_is_merged(self):
|
||||
"""A total of 0.0 (free tier) should still be merged."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
|
||||
block = ExaContentsBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.statuses = None
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.0)
|
||||
|
||||
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.0)
|
||||
|
||||
|
||||
class TestExaSearchCostTracking:
|
||||
"""ExaSearchBlock merges cost_dollars.total as provider_cost."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_dollars_total_is_merged(self):
|
||||
"""When the SDK response includes cost_dollars, its total is merged."""
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
from backend.blocks.exa.search import ExaSearchBlock
|
||||
|
||||
block = ExaSearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.resolved_search_type = None
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.008)
|
||||
|
||||
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.008)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_dollars_skips_merge(self):
|
||||
"""When cost_dollars is absent, merge_stats is not called."""
|
||||
from backend.blocks.exa.search import ExaSearchBlock
|
||||
|
||||
block = ExaSearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.resolved_search_type = None
|
||||
mock_sdk_response.cost_dollars = None
|
||||
|
||||
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 0
|
||||
|
||||
|
||||
class TestExaSimilarCostTracking:
|
||||
"""ExaFindSimilarBlock merges cost_dollars.total as provider_cost."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_dollars_total_is_merged(self):
|
||||
"""When the SDK response includes cost_dollars, its total is merged."""
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
from backend.blocks.exa.similar import ExaFindSimilarBlock
|
||||
|
||||
block = ExaFindSimilarBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.request_id = "req-1"
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.015)
|
||||
|
||||
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.015)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_dollars_skips_merge(self):
|
||||
"""When cost_dollars is absent, merge_stats is not called."""
|
||||
from backend.blocks.exa.similar import ExaFindSimilarBlock
|
||||
|
||||
block = ExaFindSimilarBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.request_id = "req-2"
|
||||
mock_sdk_response.cost_dollars = None
|
||||
|
||||
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaCreateResearchBlock — cost_dollars from completed poll response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
COMPLETED_RESEARCH_RESPONSE = {
|
||||
"researchId": "test-research-id",
|
||||
"status": "completed",
|
||||
"model": "exa-research",
|
||||
"instructions": "test instructions",
|
||||
"createdAt": 1700000000000,
|
||||
"finishedAt": 1700000060000,
|
||||
"costDollars": {
|
||||
"total": 0.05,
|
||||
"numSearches": 3,
|
||||
"numPages": 10,
|
||||
"reasoningTokens": 500,
|
||||
},
|
||||
"output": {"content": "Research findings...", "parsed": None},
|
||||
}
|
||||
|
||||
PENDING_RESEARCH_RESPONSE = {
|
||||
"researchId": "test-research-id",
|
||||
"status": "pending",
|
||||
"model": "exa-research",
|
||||
"instructions": "test instructions",
|
||||
"createdAt": 1700000000000,
|
||||
}
|
||||
|
||||
|
||||
class TestExaCreateResearchBlockCostTracking:
|
||||
"""ExaCreateResearchBlock merges cost from completed poll response."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_merged_when_research_completes(self):
|
||||
"""merge_stats called with provider_cost=total when poll returns completed."""
|
||||
from backend.blocks.exa.research import ExaCreateResearchBlock
|
||||
|
||||
block = ExaCreateResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
|
||||
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.post = AsyncMock(return_value=create_resp)
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
instructions="test instructions",
|
||||
wait_for_completion=True,
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.05)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_when_no_cost_dollars(self):
|
||||
"""When completed response has no costDollars, merge_stats is not called."""
|
||||
from backend.blocks.exa.research import ExaCreateResearchBlock
|
||||
|
||||
block = ExaCreateResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
|
||||
create_resp = MagicMock()
|
||||
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = no_cost_response
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.post = AsyncMock(return_value=create_resp)
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
instructions="test instructions",
|
||||
wait_for_completion=True,
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merged == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaGetResearchBlock — cost_dollars from single GET response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaGetResearchBlockCostTracking:
|
||||
"""ExaGetResearchBlock merges cost when the fetched research has cost_dollars."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_merged_from_completed_research(self):
|
||||
"""merge_stats called with provider_cost=total when research has costDollars."""
|
||||
from backend.blocks.exa.research import ExaGetResearchBlock
|
||||
|
||||
block = ExaGetResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
get_resp = MagicMock()
|
||||
get_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=get_resp)
|
||||
|
||||
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.05)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_when_no_cost_dollars(self):
|
||||
"""When research has no costDollars, merge_stats is not called."""
|
||||
from backend.blocks.exa.research import ExaGetResearchBlock
|
||||
|
||||
block = ExaGetResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
|
||||
get_resp = MagicMock()
|
||||
get_resp.json.return_value = no_cost_response
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=get_resp)
|
||||
|
||||
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merged == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaWaitForResearchBlock — cost_dollars from polling response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaWaitForResearchBlockCostTracking:
|
||||
"""ExaWaitForResearchBlock merges cost when the polled research has cost_dollars."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_merged_when_research_completes(self):
|
||||
"""merge_stats called with provider_cost=total once polling returns completed."""
|
||||
from backend.blocks.exa.research import ExaWaitForResearchBlock
|
||||
|
||||
block = ExaWaitForResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.05)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_when_no_cost_dollars(self):
|
||||
"""When completed research has no costDollars, merge_stats is not called."""
|
||||
from backend.blocks.exa.research import ExaWaitForResearchBlock
|
||||
|
||||
block = ExaWaitForResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = no_cost_response
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merged == []
|
||||
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -232,6 +233,11 @@ class ExaCreateResearchBlock(Block):
|
||||
|
||||
if research.cost_dollars:
|
||||
yield "cost_total", research.cost_dollars.total
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=research.cost_dollars.total
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
await asyncio.sleep(check_interval)
|
||||
@@ -346,6 +352,9 @@ class ExaGetResearchBlock(Block):
|
||||
yield "cost_searches", research.cost_dollars.num_searches
|
||||
yield "cost_pages", research.cost_dollars.num_pages
|
||||
yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=research.cost_dollars.total)
|
||||
)
|
||||
|
||||
yield "error_message", research.error
|
||||
|
||||
@@ -432,6 +441,9 @@ class ExaWaitForResearchBlock(Block):
|
||||
|
||||
if research.cost_dollars:
|
||||
yield "cost_total", research.cost_dollars.total
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=research.cost_dollars.total)
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -206,3 +207,6 @@ class ExaSearchBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -167,3 +168,6 @@ class ExaFindSimilarBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -117,6 +118,11 @@ class GoogleMapsSearchBlock(Block):
|
||||
input_data.radius,
|
||||
input_data.max_results,
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(places)), provider_cost_type="items"
|
||||
)
|
||||
)
|
||||
for place in places:
|
||||
yield "place", place
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ import copy
|
||||
from datetime import date, time
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -467,7 +469,8 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
|
||||
class AgentDropdownInputBlock(AgentInputBlock):
|
||||
"""
|
||||
A specialized text input block that relies on placeholder_values to present a dropdown.
|
||||
A specialized text input block that presents a dropdown selector
|
||||
restricted to a fixed set of values.
|
||||
"""
|
||||
|
||||
class Input(AgentInputBlock.Input):
|
||||
@@ -477,16 +480,23 @@ class AgentDropdownInputBlock(AgentInputBlock):
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="Possible values for the dropdown.",
|
||||
# Use Field() directly (not SchemaField) to pass validation_alias,
|
||||
# which handles backward compat for legacy "placeholder_values" across
|
||||
# all construction paths (model_construct, __init__, model_validate).
|
||||
options: list = Field(
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
title="Dropdown Options",
|
||||
description=(
|
||||
"If provided, renders the input as a dropdown selector "
|
||||
"restricted to these values. Leave empty for free-text input."
|
||||
),
|
||||
validation_alias=AliasChoices("options", "placeholder_values"),
|
||||
json_schema_extra={"advanced": False, "secret": False},
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = super().generate_schema()
|
||||
if possible_values := self.placeholder_values:
|
||||
if possible_values := self.options:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
|
||||
@@ -504,13 +514,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
|
||||
{
|
||||
"value": "Option A",
|
||||
"name": "dropdown_1",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"options": ["Option A", "Option B", "Option C"],
|
||||
"description": "Dropdown example 1",
|
||||
},
|
||||
{
|
||||
"value": "Option C",
|
||||
"name": "dropdown_2",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"options": ["Option A", "Option B", "Option C"],
|
||||
"description": "Dropdown example 2",
|
||||
},
|
||||
],
|
||||
|
||||
@@ -10,7 +10,7 @@ from backend.blocks.jina._auth import (
|
||||
JinaCredentialsField,
|
||||
JinaCredentialsInput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util.request import Requests
|
||||
|
||||
|
||||
@@ -45,5 +45,13 @@ class JinaEmbeddingBlock(Block):
|
||||
}
|
||||
data = {"input": input_data.texts, "model": input_data.model}
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
embeddings = [e["embedding"] for e in response.json()["data"]]
|
||||
resp_json = response.json()
|
||||
embeddings = [e["embedding"] for e in resp_json["data"]]
|
||||
usage = resp_json.get("usage", {})
|
||||
if usage.get("total_tokens"):
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=usage.get("total_tokens", 0),
|
||||
)
|
||||
)
|
||||
yield "embeddings", embeddings
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# This file contains a lot of prompt block strings that would trigger "line too long"
|
||||
# flake8: noqa: E501
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import secrets
|
||||
from abc import ABC
|
||||
@@ -13,6 +14,7 @@ import ollama
|
||||
import openai
|
||||
from anthropic.types import ToolParam
|
||||
from groq import AsyncGroq
|
||||
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -205,6 +207,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
KIMI_K2 = "moonshotai/kimi-k2"
|
||||
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
|
||||
QWEN3_CODER = "qwen/qwen3-coder"
|
||||
# Z.ai (Zhipu) models
|
||||
ZAI_GLM_4_32B = "z-ai/glm-4-32b"
|
||||
ZAI_GLM_4_5 = "z-ai/glm-4.5"
|
||||
ZAI_GLM_4_5_AIR = "z-ai/glm-4.5-air"
|
||||
ZAI_GLM_4_5_AIR_FREE = "z-ai/glm-4.5-air:free"
|
||||
ZAI_GLM_4_5V = "z-ai/glm-4.5v"
|
||||
ZAI_GLM_4_6 = "z-ai/glm-4.6"
|
||||
ZAI_GLM_4_6V = "z-ai/glm-4.6v"
|
||||
ZAI_GLM_4_7 = "z-ai/glm-4.7"
|
||||
ZAI_GLM_4_7_FLASH = "z-ai/glm-4.7-flash"
|
||||
ZAI_GLM_5 = "z-ai/glm-5"
|
||||
ZAI_GLM_5_TURBO = "z-ai/glm-5-turbo"
|
||||
ZAI_GLM_5V_TURBO = "z-ai/glm-5v-turbo"
|
||||
# Llama API models
|
||||
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
@@ -630,6 +645,43 @@ MODEL_METADATA = {
|
||||
LlmModel.QWEN3_CODER: ModelMetadata(
|
||||
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
|
||||
),
|
||||
# https://openrouter.ai/models?q=z-ai
|
||||
LlmModel.ZAI_GLM_4_32B: ModelMetadata(
|
||||
"open_router", 128000, 128000, "GLM 4 32B", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5: ModelMetadata(
|
||||
"open_router", 131072, 98304, "GLM 4.5", "OpenRouter", "Z.ai", 2
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5_AIR: ModelMetadata(
|
||||
"open_router", 131072, 98304, "GLM 4.5 Air", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5_AIR_FREE: ModelMetadata(
|
||||
"open_router", 131072, 96000, "GLM 4.5 Air (Free)", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5V: ModelMetadata(
|
||||
"open_router", 65536, 16384, "GLM 4.5V", "OpenRouter", "Z.ai", 2
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_6: ModelMetadata(
|
||||
"open_router", 204800, 204800, "GLM 4.6", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_6V: ModelMetadata(
|
||||
"open_router", 131072, 131072, "GLM 4.6V", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_7: ModelMetadata(
|
||||
"open_router", 202752, 65535, "GLM 4.7", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_7_FLASH: ModelMetadata(
|
||||
"open_router", 202752, 202752, "GLM 4.7 Flash", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_5: ModelMetadata(
|
||||
"open_router", 80000, 80000, "GLM 5", "OpenRouter", "Z.ai", 2
|
||||
),
|
||||
LlmModel.ZAI_GLM_5_TURBO: ModelMetadata(
|
||||
"open_router", 202752, 131072, "GLM 5 Turbo", "OpenRouter", "Z.ai", 3
|
||||
),
|
||||
LlmModel.ZAI_GLM_5V_TURBO: ModelMetadata(
|
||||
"open_router", 202752, 131072, "GLM 5V Turbo", "OpenRouter", "Z.ai", 3
|
||||
),
|
||||
# Llama API models
|
||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
|
||||
"llama_api",
|
||||
@@ -687,6 +739,7 @@ class LLMResponse(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
reasoning: Optional[str] = None
|
||||
provider_cost: float | None = None
|
||||
|
||||
|
||||
def convert_openai_tool_fmt_to_anthropic(
|
||||
@@ -721,6 +774,35 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
return anthropic_tools
|
||||
|
||||
|
||||
def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None:
|
||||
"""Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response.
|
||||
|
||||
OpenRouter returns the per-request USD cost in a response header. The
|
||||
OpenAI SDK exposes the raw httpx response via an undocumented `_response`
|
||||
attribute. We use try/except AttributeError so that if the SDK ever drops
|
||||
or renames that attribute, the warning is visible in logs rather than
|
||||
silently degrading to no cost tracking.
|
||||
"""
|
||||
try:
|
||||
raw_resp = response._response # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"OpenAI SDK response missing _response attribute"
|
||||
" — OpenRouter cost tracking unavailable"
|
||||
)
|
||||
return None
|
||||
try:
|
||||
cost_header = raw_resp.headers.get("x-total-cost")
|
||||
if not cost_header:
|
||||
return None
|
||||
cost = float(cost_header)
|
||||
if not math.isfinite(cost):
|
||||
return None
|
||||
return cost
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def extract_openai_reasoning(response) -> str | None:
|
||||
"""Extract reasoning from OpenAI-compatible response if available."""
|
||||
"""Note: This will likely not working since the reasoning is not present in another Response API"""
|
||||
@@ -1053,6 +1135,7 @@ async def llm_call(
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
provider_cost=extract_openrouter_cost(response),
|
||||
)
|
||||
elif provider == "llama_api":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
@@ -1360,6 +1443,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
error_feedback_message = ""
|
||||
llm_model = input_data.model
|
||||
last_attempt_cost: float | None = None
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
@@ -1377,12 +1461,15 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
response_text = llm_response.response
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=llm_response.prompt_tokens,
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
# Merge token counts for every attempt (each call costs tokens).
|
||||
# provider_cost (actual USD) is tracked separately and only merged
|
||||
# on success to avoid double-counting across retries.
|
||||
token_stats = NodeExecutionStats(
|
||||
input_token_count=llm_response.prompt_tokens,
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
self.merge_stats(token_stats)
|
||||
last_attempt_cost = llm_response.provider_cost
|
||||
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
@@ -1451,6 +1538,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
llm_retry_count=retry_count,
|
||||
provider_cost=last_attempt_cost,
|
||||
)
|
||||
)
|
||||
yield "response", response_obj
|
||||
@@ -1471,6 +1559,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
llm_retry_count=retry_count,
|
||||
provider_cost=last_attempt_cost,
|
||||
)
|
||||
)
|
||||
yield "response", {"response": response_text}
|
||||
|
||||
@@ -89,6 +89,12 @@ class MCPToolBlock(Block):
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
tool_description: str = SchemaField(
|
||||
description="Description of the selected MCP tool. "
|
||||
"Populated automatically when a tool is selected.",
|
||||
default="",
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
tool_arguments: dict[str, Any] = SchemaField(
|
||||
description="Arguments to pass to the selected MCP tool. "
|
||||
|
||||
@@ -23,7 +23,7 @@ from backend.blocks.smartlead.models import (
|
||||
SaveSequencesResponse,
|
||||
Sequence,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class CreateCampaignBlock(Block):
|
||||
@@ -226,6 +226,12 @@ class AddLeadToCampaignBlock(Block):
|
||||
response = await self.add_leads_to_campaign(
|
||||
input_data.campaign_id, input_data.lead_list, credentials
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(input_data.lead_list)),
|
||||
provider_cost_type="items",
|
||||
)
|
||||
)
|
||||
|
||||
yield "campaign_id", input_data.campaign_id
|
||||
yield "upload_count", response.upload_count
|
||||
|
||||
323
autogpt_platform/backend/backend/blocks/sql_query_block.py
Normal file
323
autogpt_platform/backend/backend/blocks/sql_query_block.py
Normal file
@@ -0,0 +1,323 @@
|
||||
import asyncio
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError, ProgrammingError
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.blocks.sql_query_helpers import (
|
||||
_DATABASE_TYPE_DEFAULT_PORT,
|
||||
_DATABASE_TYPE_TO_DRIVER,
|
||||
DatabaseType,
|
||||
_execute_query,
|
||||
_sanitize_error,
|
||||
_validate_query_is_read_only,
|
||||
_validate_single_statement,
|
||||
)
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import resolve_and_check_blocked
|
||||
|
||||
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="database",
|
||||
username=SecretStr("test_user"),
|
||||
password=SecretStr("test_pass"),
|
||||
title="Mock Database credentials",
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
DatabaseCredentials = UserPasswordCredentials
|
||||
DatabaseCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.DATABASE],
|
||||
Literal["user_password"],
|
||||
]
|
||||
|
||||
|
||||
def DatabaseCredentialsField() -> DatabaseCredentialsInput:
|
||||
return CredentialsField(
|
||||
description="Database username and password",
|
||||
)
|
||||
|
||||
|
||||
class SQLQueryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
database_type: DatabaseType = SchemaField(
|
||||
default=DatabaseType.POSTGRES,
|
||||
description="Database engine",
|
||||
advanced=False,
|
||||
)
|
||||
host: SecretStr = SchemaField(
|
||||
description=(
|
||||
"Database hostname or IP address. "
|
||||
"Treated as a secret to avoid leaking infrastructure details. "
|
||||
"Private/internal IPs are blocked (SSRF protection)."
|
||||
),
|
||||
placeholder="db.example.com",
|
||||
secret=True,
|
||||
)
|
||||
port: int | None = SchemaField(
|
||||
default=None,
|
||||
description=(
|
||||
"Database port (leave empty for default: "
|
||||
"PostgreSQL: 5432, MySQL: 3306, MSSQL: 1433)"
|
||||
),
|
||||
ge=1,
|
||||
le=65535,
|
||||
)
|
||||
database: str = SchemaField(
|
||||
description="Name of the database to connect to",
|
||||
placeholder="my_database",
|
||||
)
|
||||
query: str = SchemaField(
|
||||
description="SQL query to execute",
|
||||
placeholder="SELECT * FROM analytics.daily_active_users LIMIT 10",
|
||||
)
|
||||
read_only: bool = SchemaField(
|
||||
default=True,
|
||||
description=(
|
||||
"When enabled (default), only SELECT queries are allowed "
|
||||
"and the database session is set to read-only mode. "
|
||||
"Disable to allow write operations (INSERT, UPDATE, DELETE, etc.)."
|
||||
),
|
||||
)
|
||||
timeout: int = SchemaField(
|
||||
default=30,
|
||||
description="Query timeout in seconds (max 120)",
|
||||
ge=1,
|
||||
le=120,
|
||||
)
|
||||
max_rows: int = SchemaField(
|
||||
default=1000,
|
||||
description="Maximum number of rows to return (max 10000)",
|
||||
ge=1,
|
||||
le=10000,
|
||||
)
|
||||
credentials: DatabaseCredentialsInput = DatabaseCredentialsField()
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
results: list[dict[str, Any]] = SchemaField(
|
||||
description="Query results as a list of row dictionaries"
|
||||
)
|
||||
columns: list[str] = SchemaField(
|
||||
description="Column names from the query result"
|
||||
)
|
||||
row_count: int = SchemaField(description="Number of rows returned")
|
||||
truncated: bool = SchemaField(
|
||||
description=(
|
||||
"True when the result set was capped by max_rows, "
|
||||
"indicating additional rows exist in the database"
|
||||
)
|
||||
)
|
||||
affected_rows: int = SchemaField(
|
||||
description="Number of rows affected by a write query (INSERT/UPDATE/DELETE)"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the query failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4dc35c0f-4fd8-465e-9616-5a216f1ba2bc",
|
||||
description=(
|
||||
"Execute a SQL query. Read-only by default for safety "
|
||||
"-- disable to allow write operations. "
|
||||
"Supports PostgreSQL, MySQL, and MSSQL via SQLAlchemy."
|
||||
),
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=SQLQueryBlock.Input,
|
||||
output_schema=SQLQueryBlock.Output,
|
||||
test_input={
|
||||
"query": "SELECT 1 AS test_col",
|
||||
"database_type": DatabaseType.POSTGRES,
|
||||
"host": "localhost",
|
||||
"database": "test_db",
|
||||
"timeout": 30,
|
||||
"max_rows": 1000,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("results", [{"test_col": 1}]),
|
||||
("columns", ["test_col"]),
|
||||
("row_count", 1),
|
||||
("truncated", False),
|
||||
],
|
||||
test_mock={
|
||||
"execute_query": lambda *_args, **_kwargs: (
|
||||
[{"test_col": 1}],
|
||||
["test_col"],
|
||||
-1,
|
||||
False,
|
||||
),
|
||||
"check_host_allowed": lambda *_args, **_kwargs: ["127.0.0.1"],
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def check_host_allowed(host: str) -> list[str]:
|
||||
"""Validate that the given host is not a private/blocked address.
|
||||
|
||||
Returns the list of resolved IP addresses so the caller can pin the
|
||||
connection to the validated IP (preventing DNS rebinding / TOCTOU).
|
||||
Raises ValueError or OSError if the host is blocked.
|
||||
Extracted as a method so it can be mocked during block tests.
|
||||
"""
|
||||
return await resolve_and_check_blocked(host)
|
||||
|
||||
@staticmethod
|
||||
def execute_query(
|
||||
connection_url: URL | str,
|
||||
query: str,
|
||||
timeout: int,
|
||||
max_rows: int,
|
||||
read_only: bool = True,
|
||||
database_type: DatabaseType = DatabaseType.POSTGRES,
|
||||
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
|
||||
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
|
||||
|
||||
Delegates to ``_execute_query`` in ``sql_query_helpers``.
|
||||
Extracted as a method so it can be mocked during block tests.
|
||||
"""
|
||||
return _execute_query(
|
||||
connection_url=connection_url,
|
||||
query=query,
|
||||
timeout=timeout,
|
||||
max_rows=max_rows,
|
||||
read_only=read_only,
|
||||
database_type=database_type,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: DatabaseCredentials,
|
||||
**_kwargs: Any,
|
||||
) -> BlockOutput:
|
||||
# Validate query structure and read-only constraints.
|
||||
error = self._validate_query(input_data)
|
||||
if error:
|
||||
yield "error", error
|
||||
return
|
||||
|
||||
# Validate host and resolve for SSRF protection.
|
||||
host, pinned_host, error = await self._resolve_host(input_data)
|
||||
if error:
|
||||
yield "error", error
|
||||
return
|
||||
|
||||
# Build connection URL and execute.
|
||||
port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
|
||||
username = credentials.username.get_secret_value()
|
||||
connection_url = URL.create(
|
||||
drivername=_DATABASE_TYPE_TO_DRIVER[input_data.database_type],
|
||||
username=username,
|
||||
password=credentials.password.get_secret_value(),
|
||||
host=pinned_host,
|
||||
port=port,
|
||||
database=input_data.database,
|
||||
)
|
||||
conn_str = connection_url.render_as_string(hide_password=True)
|
||||
db_name = input_data.database
|
||||
|
||||
def _sanitize(err: Exception) -> str:
|
||||
return _sanitize_error(
|
||||
str(err).strip(),
|
||||
conn_str,
|
||||
host=pinned_host,
|
||||
original_host=host,
|
||||
username=username,
|
||||
port=port,
|
||||
database=db_name,
|
||||
)
|
||||
|
||||
try:
|
||||
results, columns, affected, truncated = await asyncio.to_thread(
|
||||
self.execute_query,
|
||||
connection_url=connection_url,
|
||||
query=input_data.query,
|
||||
timeout=input_data.timeout,
|
||||
max_rows=input_data.max_rows,
|
||||
read_only=input_data.read_only,
|
||||
database_type=input_data.database_type,
|
||||
)
|
||||
yield "results", results
|
||||
yield "columns", columns
|
||||
yield "row_count", len(results)
|
||||
yield "truncated", truncated
|
||||
if affected >= 0:
|
||||
yield "affected_rows", affected
|
||||
except OperationalError as e:
|
||||
yield (
|
||||
"error",
|
||||
self._classify_operational_error(
|
||||
_sanitize(e),
|
||||
input_data.timeout,
|
||||
),
|
||||
)
|
||||
except ProgrammingError as e:
|
||||
yield "error", f"SQL error: {_sanitize(e)}"
|
||||
except DBAPIError as e:
|
||||
yield "error", f"Database error: {_sanitize(e)}"
|
||||
except ModuleNotFoundError:
|
||||
yield (
|
||||
"error",
|
||||
(
|
||||
f"Database driver not available for "
|
||||
f"{input_data.database_type.value}. "
|
||||
f"Please contact the platform administrator."
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_query(input_data: "SQLQueryBlock.Input") -> str | None:
|
||||
"""Validate query structure and read-only constraints."""
|
||||
stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
|
||||
if stmt_error:
|
||||
return stmt_error
|
||||
assert parsed_stmt is not None
|
||||
if input_data.read_only:
|
||||
return _validate_query_is_read_only(parsed_stmt)
|
||||
return None
|
||||
|
||||
async def _resolve_host(
|
||||
self, input_data: "SQLQueryBlock.Input"
|
||||
) -> tuple[str, str, str | None]:
|
||||
"""Validate and resolve the database host. Returns (host, pinned_ip, error)."""
|
||||
host = input_data.host.get_secret_value().strip()
|
||||
if not host:
|
||||
return "", "", "Database host is required."
|
||||
if host.startswith("/"):
|
||||
return host, "", "Unix socket connections are not allowed."
|
||||
try:
|
||||
resolved_ips = await self.check_host_allowed(host)
|
||||
except (ValueError, OSError) as e:
|
||||
return host, "", f"Blocked host: {str(e).strip()}"
|
||||
return host, resolved_ips[0], None
|
||||
|
||||
@staticmethod
|
||||
def _classify_operational_error(sanitized_msg: str, timeout: int) -> str:
|
||||
"""Classify an already-sanitized OperationalError for user display."""
|
||||
lower = sanitized_msg.lower()
|
||||
if "timeout" in lower or "cancel" in lower:
|
||||
return f"Query timed out after {timeout}s."
|
||||
if "connect" in lower:
|
||||
return f"Failed to connect to database: {sanitized_msg}"
|
||||
return f"Database error: {sanitized_msg}"
|
||||
1851
autogpt_platform/backend/backend/blocks/sql_query_block_test.py
Normal file
1851
autogpt_platform/backend/backend/blocks/sql_query_block_test.py
Normal file
File diff suppressed because it is too large
Load Diff
430
autogpt_platform/backend/backend/blocks/sql_query_helpers.py
Normal file
430
autogpt_platform/backend/backend/blocks/sql_query_helpers.py
Normal file
@@ -0,0 +1,430 @@
|
||||
import re
|
||||
from datetime import date, datetime, time
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import sqlparse
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine.url import URL
|
||||
|
||||
|
||||
class DatabaseType(str, Enum):
|
||||
POSTGRES = "postgres"
|
||||
MYSQL = "mysql"
|
||||
MSSQL = "mssql"
|
||||
|
||||
|
||||
# Defense-in-depth: reject queries containing data-modifying keywords.
|
||||
# These are checked against parsed SQL tokens (not raw text) so column names
|
||||
# and string literals do not cause false positives.
|
||||
_DISALLOWED_KEYWORDS = {
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"TRUNCATE",
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
"COPY",
|
||||
"EXECUTE",
|
||||
"CALL",
|
||||
"SET",
|
||||
"RESET",
|
||||
"DISCARD",
|
||||
"NOTIFY",
|
||||
"DO",
|
||||
# MySQL file exfiltration: LOAD DATA LOCAL INFILE reads server/client files
|
||||
"LOAD",
|
||||
# MySQL REPLACE is INSERT-or-UPDATE; data modification
|
||||
"REPLACE",
|
||||
# ANSI MERGE (UPSERT) modifies data
|
||||
"MERGE",
|
||||
# MSSQL BULK INSERT loads external files into tables
|
||||
"BULK",
|
||||
# MSSQL EXEC / EXEC sp_name runs stored procedures (arbitrary code)
|
||||
"EXEC",
|
||||
}
|
||||
|
||||
# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
|
||||
_DATABASE_TYPE_TO_DRIVER = {
|
||||
DatabaseType.POSTGRES: "postgresql",
|
||||
DatabaseType.MYSQL: "mysql+pymysql",
|
||||
DatabaseType.MSSQL: "mssql+pymssql",
|
||||
}
|
||||
|
||||
# Connection timeout in seconds passed to the DBAPI driver (connect_timeout /
|
||||
# login_timeout). This bounds how long the driver waits to establish a TCP
|
||||
# connection to the database server. It is separate from the per-statement
|
||||
# timeout configured via SET commands inside _configure_session().
|
||||
_CONNECT_TIMEOUT_SECONDS = 10
|
||||
|
||||
# Default ports for each database type.
|
||||
_DATABASE_TYPE_DEFAULT_PORT = {
|
||||
DatabaseType.POSTGRES: 5432,
|
||||
DatabaseType.MYSQL: 3306,
|
||||
DatabaseType.MSSQL: 1433,
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_error(
|
||||
error_msg: str,
|
||||
connection_string: str,
|
||||
*,
|
||||
host: str = "",
|
||||
original_host: str = "",
|
||||
username: str = "",
|
||||
port: int = 0,
|
||||
database: str = "",
|
||||
) -> str:
|
||||
"""Remove connection string, credentials, and infrastructure details
|
||||
from error messages so they are safe to expose to the LLM.
|
||||
|
||||
Scrubs:
|
||||
- The full connection string
|
||||
- URL-embedded credentials (``://user:pass@``)
|
||||
- ``password=<value>`` key-value pairs
|
||||
- The database hostname / IP used for the connection
|
||||
- The original (pre-resolution) hostname provided by the user
|
||||
- Any IPv4 addresses that appear in the message
|
||||
- Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
|
||||
- The database username
|
||||
- The database port number
|
||||
- The database name
|
||||
"""
|
||||
sanitized = error_msg.replace(connection_string, "<connection_string>")
|
||||
sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
|
||||
sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)
|
||||
|
||||
# Replace the known host (may be an IP already) before the generic IP pass.
|
||||
# Also replace the original (pre-DNS-resolution) hostname if it differs.
|
||||
if original_host and original_host != host:
|
||||
sanitized = sanitized.replace(original_host, "<host>")
|
||||
if host:
|
||||
sanitized = sanitized.replace(host, "<host>")
|
||||
|
||||
# Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
|
||||
sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)
|
||||
|
||||
# Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
|
||||
sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)
|
||||
|
||||
# Replace the database username (handles double-quoted, single-quoted,
|
||||
# and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
|
||||
if username:
|
||||
sanitized = re.sub(
|
||||
r"""for user ["']?""" + re.escape(username) + r"""["']?""",
|
||||
"for user <user>",
|
||||
sanitized,
|
||||
)
|
||||
# Catch remaining bare occurrences in various quote styles:
|
||||
# - PostgreSQL: "FATAL: role "myuser" does not exist"
|
||||
# - MySQL: "Access denied for user 'myuser'@'host'"
|
||||
# - MSSQL: "Login failed for user 'myuser'"
|
||||
sanitized = sanitized.replace(f'"{username}"', "<user>")
|
||||
sanitized = sanitized.replace(f"'{username}'", "<user>")
|
||||
|
||||
# Replace the port number (handles "port 5432" and ":5432" formats)
|
||||
if port:
|
||||
port_str = re.escape(str(port))
|
||||
sanitized = re.sub(
|
||||
r"(?:port |:)" + port_str + r"(?![0-9])",
|
||||
lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
|
||||
sanitized,
|
||||
)
|
||||
|
||||
# Replace the database name to avoid leaking internal infrastructure names.
|
||||
# Use word-boundary regex to prevent mangling when the database name is a
|
||||
# common substring (e.g. "test", "data", "on").
|
||||
if database:
|
||||
sanitized = re.sub(r"\b" + re.escape(database) + r"\b", "<database>", sanitized)
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
    """Collect the uppercased keyword tokens of a parsed SQL statement.

    Relies on sqlparse's token-type classification: only ``Keyword`` and its
    DML/DDL/DCL sub-types are gathered. Identifiers and string literals carry
    different token types, so they never appear in the result.
    """
    keyword_types = (
        sqlparse.tokens.Keyword,
        sqlparse.tokens.Keyword.DML,
        sqlparse.tokens.Keyword.DDL,
        sqlparse.tokens.Keyword.DCL,
    )
    keywords: list[str] = []
    for token in parsed.flatten():
        if token.ttype in keyword_types:
            keywords.append(token.normalized.upper())
    return keywords
|
||||
|
||||
|
||||
def _has_disallowed_into(stmt: sqlparse.sql.Statement) -> bool:
    """Check if a statement contains a disallowed ``INTO`` clause.

    ``SELECT ... INTO @variable`` is a valid read-only MySQL syntax that stores
    a query result into a session-scoped user variable. All other forms of
    ``INTO`` are data-modifying or file-writing and must be blocked:

    * ``SELECT ... INTO new_table`` (PostgreSQL / MSSQL – creates a table)
    * ``SELECT ... INTO OUTFILE`` (MySQL – writes to the filesystem)
    * ``SELECT ... INTO DUMPFILE`` (MySQL – writes to the filesystem)
    * ``INSERT INTO ...`` (already blocked by INSERT being in the
      disallowed set, but we reject INTO as well for defense-in-depth)

    Returns ``True`` if the statement contains a disallowed ``INTO``.
    """
    flat = list(stmt.flatten())
    for i, token in enumerate(flat):
        if token.ttype is not sqlparse.tokens.Keyword:
            continue
        if token.normalized.upper() != "INTO":
            continue

        # Look at the first non-whitespace token after INTO.
        # BUGFIX: the previous identity check (``ttype is Text.Whitespace``)
        # did not match whitespace sub-types such as Whitespace.Newline, so
        # a line break between INTO and the target caused valid read-only
        # ``INTO\n@var`` to be rejected. ``is_whitespace`` covers all
        # whitespace token types.
        j = i + 1
        while j < len(flat) and flat[j].is_whitespace:
            j += 1

        if j >= len(flat):
            # INTO at the very end – malformed, block it.
            return True

        next_token = flat[j]
        # MySQL user variable: either a single Name starting with "@"
        # (e.g. ``@total``) or a bare ``@`` Operator token followed by a Name.
        if next_token.ttype is sqlparse.tokens.Name and next_token.value.startswith(
            "@"
        ):
            continue
        if next_token.ttype is sqlparse.tokens.Operator and next_token.value == "@":
            continue

        # Everything else (table name, OUTFILE, DUMPFILE, etc.) is disallowed.
        return True

    return False
|
||||
|
||||
|
||||
def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
    """Validate that a parsed SQL statement is read-only (SELECT/WITH only).

    Accepts an already-parsed statement from ``_validate_single_statement``
    to avoid re-parsing. Three checks are applied in order:

    1. Statement type must be SELECT (sqlparse reports WITH...SELECT as SELECT)
    2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, ...)
    3. No disallowed INTO clauses (MySQL ``SELECT ... INTO @variable`` is OK)

    Returns an error message if the query is not read-only, None otherwise.
    """
    # sqlparse classifies both plain SELECT and WITH...SELECT as 'SELECT'.
    if stmt.get_type() != "SELECT":
        return "Only SELECT queries are allowed."

    # Defense-in-depth: scan every keyword token for write-capable verbs.
    for keyword in _extract_keyword_tokens(stmt):
        # Multi-word tokens (e.g. "SET LOCAL") are reduced to their head word.
        head = keyword.split()[0] if " " in keyword else keyword
        if head in _DISALLOWED_KEYWORDS:
            return f"Disallowed SQL keyword: {keyword}"

    # INTO requires context: MySQL @variable targets are allowed, all other
    # forms (new table, OUTFILE, DUMPFILE) are blocked.
    if _has_disallowed_into(stmt):
        return "Disallowed SQL keyword: INTO"

    return None
|
||||
|
||||
|
||||
def _validate_single_statement(
    query: str,
) -> tuple[str | None, sqlparse.sql.Statement | None]:
    """Validate that the query contains exactly one non-empty SQL statement.

    Returns ``(error_message, parsed_statement)``. When the error message is
    not ``None`` the query is invalid and the parsed statement is ``None``.
    """
    stripped = query.strip().rstrip(";").strip()
    if not stripped:
        return "Query is empty.", None

    def _is_substantive(candidate: sqlparse.sql.Statement) -> bool:
        # A statement counts only if it has tokens, non-blank text, and at
        # least one token that is neither whitespace nor a comment.
        if not candidate.tokens or not str(candidate).strip():
            return False
        return any(
            not (t.is_whitespace or t.ttype in sqlparse.tokens.Comment)
            for t in candidate.flatten()
        )

    # Tokenize with sqlparse and keep only substantive statements.
    statements = [s for s in sqlparse.parse(stripped) if _is_substantive(s)]

    if not statements:
        return "Query is empty.", None

    # Rejecting multiple statements prevents injection via semicolons.
    if len(statements) > 1:
        return "Only single statements are allowed.", None

    return None, statements[0]
|
||||
|
||||
|
||||
def _serialize_value(value: Any) -> Any:
|
||||
"""Convert database-specific types to JSON-serializable Python types."""
|
||||
if isinstance(value, Decimal):
|
||||
# NaN / Infinity are not valid JSON numbers; serialize as strings.
|
||||
if value.is_nan() or value.is_infinite():
|
||||
return str(value)
|
||||
# Use int for whole numbers; use str for fractional to preserve exact
|
||||
# precision (float would silently round high-precision analytics values).
|
||||
if value == value.to_integral_value():
|
||||
return int(value)
|
||||
return str(value)
|
||||
if isinstance(value, (datetime, date, time)):
|
||||
return value.isoformat()
|
||||
if isinstance(value, memoryview):
|
||||
return bytes(value).hex()
|
||||
if isinstance(value, bytes):
|
||||
return value.hex()
|
||||
return value
|
||||
|
||||
|
||||
def _configure_session(
    conn: Any,
    dialect_name: str,
    timeout_ms: str,
    read_only: bool,
) -> None:
    """Set session-level timeout and read-only mode for the given dialect.

    Timeout limitations by database:

    * **PostgreSQL** – ``statement_timeout`` reliably cancels any running
      statement (SELECT or DML) after the configured duration.
    * **MySQL** – ``MAX_EXECUTION_TIME`` applies only to read-only SELECT
      statements. DML and DDL are not bounded by this hint; they rely on the
      server's ``wait_timeout`` / ``interactive_timeout``. MySQL has no
      session-level setting that reliably cancels long-running writes.
    * **MSSQL** – ``SET LOCK_TIMEOUT`` only limits how long the server waits
      to acquire a lock. CPU-bound queries (large scans, hash joins) that do
      not block on locks are *not* cancelled; MSSQL has no session-level
      ``statement_timeout`` equivalent (Resource Governor would require
      sysadmin configuration).

    Unrecognised dialect names (e.g. sqlite, which the block's
    ``DatabaseType`` enum intentionally excludes) silently receive no SET
    commands.
    """
    if dialect_name == "postgresql":
        conn.execute(text("SET statement_timeout = " + timeout_ms))
        if read_only:
            conn.execute(text("SET default_transaction_read_only = ON"))
        return

    if dialect_name == "mysql":
        # MAX_EXECUTION_TIME bounds SELECTs only; writes fall back to the
        # server's wait_timeout (see docstring for full limitations).
        conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
        if read_only:
            conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
        return

    if dialect_name == "mssql":
        # LOCK_TIMEOUT caps lock-wait time (ms) only; CPU-bound queries
        # without lock contention are NOT cancelled.
        conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
        # MSSQL lacks a session-level read-only mode; enforcement happens in
        # the SQL validation layer (_validate_query_is_read_only) and the
        # ROLLBACK issued by the transaction runner.
|
||||
|
||||
|
||||
def _run_in_transaction(
    conn: Any,
    dialect_name: str,
    query: str,
    max_rows: int,
    read_only: bool,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
    """Execute *query* inside an explicit transaction and collect results.

    Returns ``(rows, columns, affected_rows, truncated)`` where *truncated*
    is ``True`` when ``fetchmany`` returned exactly ``max_rows`` rows,
    meaning additional rows may remain in the result set.
    """
    # T-SQL spells it "BEGIN TRANSACTION"; every other dialect uses "BEGIN".
    conn.execute(text("BEGIN TRANSACTION" if dialect_name == "mssql" else "BEGIN"))
    try:
        result = conn.execute(text(query))
        has_rows = result.returns_rows
        affected = -1 if has_rows else result.rowcount
        columns = list(result.keys()) if has_rows else []
        fetched = result.fetchmany(max_rows) if has_rows else []
        truncated = len(fetched) == max_rows
        serialized = [
            {col: _serialize_value(val) for col, val in zip(columns, row)}
            for row in fetched
        ]
    except Exception:
        # Best-effort rollback; always re-raise the original failure.
        try:
            conn.execute(text("ROLLBACK"))
        except Exception:
            pass
        raise
    else:
        # Read-only sessions always roll back so no writes can persist.
        conn.execute(text("ROLLBACK" if read_only else "COMMIT"))
        return serialized, columns, affected, truncated
|
||||
|
||||
|
||||
def _execute_query(
    connection_url: URL | str,
    query: str,
    timeout: int,
    max_rows: int,
    read_only: bool = True,
    database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
    """Execute a SQL query and return (rows, columns, affected_rows, truncated).

    Uses SQLAlchemy to connect to any supported database. For SELECT queries,
    rows are capped at ``max_rows`` via DBAPI fetchmany; ``truncated`` is
    ``True`` when that cap was hit. For write queries, ``affected_rows``
    carries the driver's rowcount. When ``read_only`` is True the session is
    put into read-only mode and the transaction is always rolled back.
    """
    # pymssql names its connection-timeout kwarg "login_timeout"; the
    # PostgreSQL/MySQL drivers use "connect_timeout".
    if database_type == DatabaseType.MSSQL:
        timeout_key = "login_timeout"
    else:
        timeout_key = "connect_timeout"

    engine = create_engine(
        connection_url, connect_args={timeout_key: _CONNECT_TIMEOUT_SECONDS}
    )
    try:
        with engine.connect() as conn:
            # AUTOCOMMIT makes the SET commands take effect immediately.
            conn = conn.execution_options(isolation_level="AUTOCOMMIT")

            # Timeout in milliseconds. The value is Pydantic-validated
            # (ge=1, le=120); int() is defense-in-depth. SET commands do not
            # accept bind parameters in most databases, so the value is
            # interpolated via str(int(...)).
            timeout_ms = str(int(timeout * 1000))

            dialect = engine.dialect.name
            _configure_session(conn, dialect, timeout_ms, read_only)
            return _run_in_transaction(conn, dialect, query, max_rows, read_only)
    finally:
        engine.dispose()
|
||||
@@ -300,13 +300,27 @@ def test_agent_input_block_ignores_legacy_placeholder_values():
|
||||
|
||||
|
||||
def test_dropdown_input_block_produces_enum():
|
||||
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
|
||||
options = ["Option A", "Option B"]
|
||||
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum
|
||||
using the canonical 'options' field name."""
|
||||
opts = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_construct(
|
||||
name="choice", value=None, placeholder_values=options
|
||||
name="choice", value=None, options=opts
|
||||
)
|
||||
schema = instance.generate_schema()
|
||||
assert schema.get("enum") == options
|
||||
assert schema.get("enum") == opts
|
||||
|
||||
|
||||
def test_dropdown_input_block_legacy_placeholder_values_produces_enum():
|
||||
"""Verify backward compat: passing legacy 'placeholder_values' to
|
||||
AgentDropdownInputBlock still produces enum via model_construct remap."""
|
||||
opts = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_construct(
|
||||
name="choice", value=None, placeholder_values=opts
|
||||
)
|
||||
schema = instance.generate_schema()
|
||||
assert (
|
||||
schema.get("enum") == opts
|
||||
), "Legacy placeholder_values should be remapped to options"
|
||||
|
||||
|
||||
def test_generate_schema_integration_legacy_placeholder_values():
|
||||
@@ -329,11 +343,11 @@ def test_generate_schema_integration_legacy_placeholder_values():
|
||||
|
||||
def test_generate_schema_integration_dropdown_produces_enum():
|
||||
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
|
||||
— verifies enum IS produced for dropdown blocks."""
|
||||
— verifies enum IS produced for dropdown blocks using canonical field name."""
|
||||
dropdown_input_default = {
|
||||
"name": "color",
|
||||
"value": None,
|
||||
"placeholder_values": ["Red", "Green", "Blue"],
|
||||
"options": ["Red", "Green", "Blue"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentDropdownInputBlock.Input, dropdown_input_default),
|
||||
@@ -344,3 +358,36 @@ def test_generate_schema_integration_dropdown_produces_enum():
|
||||
"Green",
|
||||
"Blue",
|
||||
], "Graph schema should contain enum from AgentDropdownInputBlock"
|
||||
|
||||
|
||||
def test_generate_schema_integration_dropdown_legacy_placeholder_values():
|
||||
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
|
||||
using legacy 'placeholder_values' — verifies backward compat produces enum."""
|
||||
legacy_dropdown_input_default = {
|
||||
"name": "color",
|
||||
"value": None,
|
||||
"placeholder_values": ["Red", "Green", "Blue"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentDropdownInputBlock.Input, legacy_dropdown_input_default),
|
||||
)
|
||||
color_props = result["properties"]["color"]
|
||||
assert color_props.get("enum") == [
|
||||
"Red",
|
||||
"Green",
|
||||
"Blue",
|
||||
], "Legacy placeholder_values should still produce enum via model_construct remap"
|
||||
|
||||
|
||||
def test_dropdown_input_block_init_legacy_placeholder_values():
|
||||
"""Verify backward compat: constructing AgentDropdownInputBlock.Input via
|
||||
model_validate with legacy 'placeholder_values' correctly maps to 'options'."""
|
||||
opts = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_validate(
|
||||
{"name": "choice", "value": None, "placeholder_values": opts}
|
||||
)
|
||||
assert (
|
||||
instance.options == opts
|
||||
), "Legacy placeholder_values should be remapped to options via model_validate"
|
||||
schema = instance.generate_schema()
|
||||
assert schema.get("enum") == opts
|
||||
|
||||
@@ -199,6 +199,66 @@ class TestLLMStatsTracking:
|
||||
assert block.execution_stats.llm_call_count == 2 # retry_count + 1 = 1 + 1 = 2
|
||||
assert block.execution_stats.llm_retry_count == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_retry_cost_uses_last_attempt_only(self):
    """provider_cost is only merged from the final successful attempt.

    Intermediate retry costs are intentionally dropped to avoid
    double-counting: the cost of failed attempts is captured in
    last_attempt_cost only when the loop eventually succeeds.
    """
    import backend.blocks.llm as llm

    block = llm.AIStructuredResponseGeneratorBlock()
    attempts = 0

    async def fake_llm_call(*args, **kwargs):
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            # Attempt 1: wrong keys → fails validation; costs $0.01.
            return llm.LLMResponse(
                raw_response="",
                prompt=[],
                response='<json_output id="test123456">{"wrong": "key"}</json_output>',
                tool_calls=None,
                prompt_tokens=10,
                completion_tokens=5,
                reasoning=None,
                provider_cost=0.01,
            )
        # Attempt 2: valid payload → succeeds; costs $0.02.
        return llm.LLMResponse(
            raw_response="",
            prompt=[],
            response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
            tool_calls=None,
            prompt_tokens=20,
            completion_tokens=10,
            reasoning=None,
            provider_cost=0.02,
        )

    block.llm_call = fake_llm_call  # type: ignore

    input_data = llm.AIStructuredResponseGeneratorBlock.Input(
        prompt="Test prompt",
        expected_format={"key1": "desc1", "key2": "desc2"},
        model=llm.DEFAULT_LLM_MODEL,
        credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
        retry=2,
    )

    with patch("secrets.token_hex", return_value="test123456"):
        async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
            pass

    # Only the final successful attempt's cost is merged.
    assert block.execution_stats.provider_cost == pytest.approx(0.02)
    # Token counts from both attempts accumulate.
    assert block.execution_stats.input_token_count == 30
    assert block.execution_stats.output_token_count == 15
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_text_summarizer_multiple_chunks(self):
|
||||
"""Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""
|
||||
@@ -987,3 +1047,63 @@ class TestLlmModelMissing:
|
||||
assert (
|
||||
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
|
||||
)
|
||||
|
||||
|
||||
class TestExtractOpenRouterCost:
    """Tests for extract_openrouter_cost — the x-total-cost header parser."""

    def _mk_response(self, headers: dict | None):
        """Build a mock response; ``headers=None`` means no raw response."""
        resp = MagicMock()
        if headers is None:
            resp._response = None
            return resp
        raw = MagicMock()
        raw.headers = headers
        resp._response = raw
        return resp

    def test_extracts_numeric_cost(self):
        resp = self._mk_response({"x-total-cost": "0.0042"})
        assert llm.extract_openrouter_cost(resp) == 0.0042

    def test_returns_none_when_header_missing(self):
        resp = self._mk_response({})
        assert llm.extract_openrouter_cost(resp) is None

    def test_returns_none_when_header_empty_string(self):
        resp = self._mk_response({"x-total-cost": ""})
        assert llm.extract_openrouter_cost(resp) is None

    def test_returns_none_when_header_non_numeric(self):
        resp = self._mk_response({"x-total-cost": "not-a-number"})
        assert llm.extract_openrouter_cost(resp) is None

    def test_returns_none_when_no_response_attr(self):
        resp = MagicMock(spec=[])  # no _response attr
        assert llm.extract_openrouter_cost(resp) is None

    def test_returns_none_when_raw_is_none(self):
        resp = self._mk_response(None)
        assert llm.extract_openrouter_cost(resp) is None

    def test_returns_none_when_raw_has_no_headers(self):
        resp = MagicMock()
        resp._response = MagicMock(spec=[])  # raw object lacks .headers
        assert llm.extract_openrouter_cost(resp) is None

    def test_returns_zero_for_zero_cost(self):
        """Zero-cost is a valid value (free tier) and must not become None."""
        resp = self._mk_response({"x-total-cost": "0"})
        assert llm.extract_openrouter_cost(resp) == 0.0

    def test_returns_none_for_inf(self):
        resp = self._mk_response({"x-total-cost": "inf"})
        assert llm.extract_openrouter_cost(resp) is None

    def test_returns_none_for_negative_inf(self):
        resp = self._mk_response({"x-total-cost": "-inf"})
        assert llm.extract_openrouter_cost(resp) is None

    def test_returns_none_for_nan(self):
        resp = self._mk_response({"x-total-cost": "nan"})
        assert llm.extract_openrouter_cost(resp) is None
|
||||
|
||||
@@ -13,6 +13,7 @@ from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -104,4 +105,10 @@ class UnrealTextToSpeechBlock(Block):
|
||||
input_data.text,
|
||||
input_data.voice_id,
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(input_data.text)),
|
||||
provider_cost_type="characters",
|
||||
)
|
||||
)
|
||||
yield "mp3_url", api_response["OutputUri"]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,799 @@
|
||||
"""Unit tests for baseline service pure-logic helpers.
|
||||
|
||||
These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState``
|
||||
without requiring API keys, database connections, or network access.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_baseline_conversation_updater,
|
||||
_BaselineStreamState,
|
||||
_compress_session_messages,
|
||||
_ThinkingStripper,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.prompt import CompressResult
|
||||
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
|
||||
|
||||
|
||||
class TestBaselineStreamState:
    """Construction and field-mutation behavior of _BaselineStreamState."""

    def test_defaults(self):
        """A fresh state starts empty with zeroed token counters."""
        fresh = _BaselineStreamState()
        assert fresh.pending_events == []
        assert fresh.assistant_text == ""
        assert fresh.text_started is False
        assert fresh.turn_prompt_tokens == 0
        assert fresh.turn_completion_tokens == 0
        assert fresh.text_block_id  # Should be a UUID string

    def test_mutable_fields(self):
        """Fields can be reassigned and read back unchanged."""
        state = _BaselineStreamState()
        state.assistant_text = "hello"
        state.turn_prompt_tokens = 100
        state.turn_completion_tokens = 50
        assert state.assistant_text == "hello"
        assert state.turn_prompt_tokens == 100
        assert state.turn_completion_tokens == 50
|
||||
|
||||
|
||||
class TestBaselineConversationUpdater:
    """Tests for _baseline_conversation_updater which updates the OpenAI
    message list and transcript builder after each LLM call."""

    def _make_transcript_builder(self) -> TranscriptBuilder:
        """Seed a transcript with one user entry, as every turn starts with."""
        transcript = TranscriptBuilder()
        transcript.append_user("test question")
        return transcript

    def _make_response(self, text, tool_calls):
        """Build an LLMLoopResponse with zeroed token counts."""
        return LLMLoopResponse(
            response_text=text,
            tool_calls=tool_calls,
            raw_response=None,
            prompt_tokens=0,
            completion_tokens=0,
        )

    def test_text_only_response(self):
        """Text without tool calls → one assistant message + transcript entry."""
        history: list = []
        transcript = self._make_transcript_builder()
        llm_response = self._make_response("Hello, world!", [])

        _baseline_conversation_updater(
            history,
            llm_response,
            tool_results=None,
            transcript_builder=transcript,
            model="test-model",
        )

        assert len(history) == 1
        assert history[0]["role"] == "assistant"
        assert history[0]["content"] == "Hello, world!"
        # Transcript now holds the user turn plus the assistant turn.
        assert transcript.entry_count == 2
        assert transcript.last_entry_type == "assistant"

    def test_tool_calls_response(self):
        """Tool calls → assistant message with tool_calls plus tool results."""
        history: list = []
        transcript = self._make_transcript_builder()
        llm_response = self._make_response(
            "Let me search...",
            [LLMToolCall(id="tc_1", name="search", arguments='{"query": "test"}')],
        )
        tool_results = [
            ToolCallResult(
                tool_call_id="tc_1",
                tool_name="search",
                content="Found result",
            ),
        ]

        _baseline_conversation_updater(
            history,
            llm_response,
            tool_results=tool_results,
            transcript_builder=transcript,
            model="test-model",
        )

        # Expected messages: assistant (with tool_calls) then the tool result.
        assert len(history) == 2
        assert history[0]["role"] == "assistant"
        assert history[0]["content"] == "Let me search..."
        assert len(history[0]["tool_calls"]) == 1
        assert history[0]["tool_calls"][0]["id"] == "tc_1"
        assert history[1]["role"] == "tool"
        assert history[1]["tool_call_id"] == "tc_1"
        assert history[1]["content"] == "Found result"
        # Transcript: user + assistant(tool_use) + user(tool_result).
        assert transcript.entry_count == 3

    def test_tool_calls_without_text(self):
        """Tool calls with no accompanying text still produce messages."""
        history: list = []
        transcript = self._make_transcript_builder()
        llm_response = self._make_response(
            None, [LLMToolCall(id="tc_1", name="run", arguments="{}")]
        )
        tool_results = [
            ToolCallResult(tool_call_id="tc_1", tool_name="run", content="done"),
        ]

        _baseline_conversation_updater(
            history,
            llm_response,
            tool_results=tool_results,
            transcript_builder=transcript,
            model="test-model",
        )

        assert len(history) == 2
        assert "content" not in history[0]  # No text content
        assert history[0]["tool_calls"][0]["function"]["name"] == "run"

    def test_no_text_no_tools(self):
        """An empty response appends nothing to messages or transcript."""
        history: list = []
        transcript = self._make_transcript_builder()
        llm_response = self._make_response(None, [])

        _baseline_conversation_updater(
            history,
            llm_response,
            tool_results=None,
            transcript_builder=transcript,
            model="test-model",
        )

        assert len(history) == 0
        # Only the user entry created in the helper remains.
        assert transcript.entry_count == 1

    def test_multiple_tool_calls(self):
        """Every tool call in a single response is recorded, in order."""
        history: list = []
        transcript = self._make_transcript_builder()
        llm_response = self._make_response(
            None,
            [
                LLMToolCall(id="tc_1", name="tool_a", arguments="{}"),
                LLMToolCall(id="tc_2", name="tool_b", arguments='{"x": 1}'),
            ],
        )
        tool_results = [
            ToolCallResult(tool_call_id="tc_1", tool_name="tool_a", content="result_a"),
            ToolCallResult(tool_call_id="tc_2", tool_name="tool_b", content="result_b"),
        ]

        _baseline_conversation_updater(
            history,
            llm_response,
            tool_results=tool_results,
            transcript_builder=transcript,
            model="test-model",
        )

        # One assistant message followed by the two tool results.
        assert len(history) == 3
        assert len(history[0]["tool_calls"]) == 2
        assert history[1]["tool_call_id"] == "tc_1"
        assert history[2]["tool_call_id"] == "tc_2"

    def test_invalid_tool_arguments_handled(self):
        """Tool call with invalid JSON arguments: the arguments field is
        stored as-is in the message, and orjson failure falls back to {}
        in the transcript content_blocks."""
        history: list = []
        transcript = self._make_transcript_builder()
        llm_response = self._make_response(
            None, [LLMToolCall(id="tc_1", name="tool_x", arguments="not-json")]
        )
        tool_results = [
            ToolCallResult(tool_call_id="tc_1", tool_name="tool_x", content="ok"),
        ]

        _baseline_conversation_updater(
            history,
            llm_response,
            tool_results=tool_results,
            transcript_builder=transcript,
            model="test-model",
        )

        # Should not raise — invalid JSON falls back to {} in transcript.
        assert len(history) == 2
        assert history[0]["tool_calls"][0]["function"]["arguments"] == "not-json"
|
||||
|
||||
|
||||
class TestCompressSessionMessagesPreservesToolCalls:
    """``_compress_session_messages`` must round-trip tool_calls + tool_call_id.

    Compression serialises ChatMessage to dict for ``compress_context`` and
    reifies the result back to ChatMessage. A regression that drops
    ``tool_calls`` or ``tool_call_id`` would corrupt the OpenAI message
    list and break downstream tool-execution rounds.
    """

    @pytest.mark.asyncio
    async def test_compressed_output_keeps_tool_calls_and_ids(self):
        # Pretend compression kept a summary plus the most recent
        # assistant(tool_call) / tool(tool_result) pair verbatim.
        summary = {"role": "system", "content": "prior turns: user asked X"}
        assistant_with_tc = {
            "role": "assistant",
            "content": "calling tool",
            "tool_calls": [
                {
                    "id": "tc_abc",
                    "type": "function",
                    "function": {"name": "search", "arguments": '{"q":"y"}'},
                }
            ],
        }
        tool_result = {
            "role": "tool",
            "tool_call_id": "tc_abc",
            "content": "search result",
        }
        fake_result = CompressResult(
            messages=[summary, assistant_with_tc, tool_result],
            token_count=100,
            was_compacted=True,
            original_token_count=5000,
            messages_summarized=10,
            messages_dropped=0,
        )

        # The session history that should be compressed.
        session = [
            ChatMessage(role="user", content="q1"),
            ChatMessage(
                role="assistant",
                content="calling tool",
                tool_calls=[
                    {
                        "id": "tc_abc",
                        "type": "function",
                        "function": {
                            "name": "search",
                            "arguments": '{"q":"y"}',
                        },
                    }
                ],
            ),
            ChatMessage(
                role="tool",
                tool_call_id="tc_abc",
                content="search result",
            ),
        ]

        with patch(
            "backend.copilot.baseline.service.compress_context",
            new=AsyncMock(return_value=fake_result),
        ):
            compressed = await _compress_session_messages(
                session, model="openrouter/anthropic/claude-opus-4"
            )

        # Summary, assistant(tool_calls), tool(tool_call_id).
        assert len(compressed) == 3
        # The assistant message must keep its tool_calls intact.
        assistant_msg = compressed[1]
        assert assistant_msg.role == "assistant"
        assert assistant_msg.tool_calls is not None
        assert len(assistant_msg.tool_calls) == 1
        assert assistant_msg.tool_calls[0]["id"] == "tc_abc"
        assert assistant_msg.tool_calls[0]["function"]["name"] == "search"
        # The tool-role message must keep tool_call_id for OpenAI linkage.
        tool_msg = compressed[2]
        assert tool_msg.role == "tool"
        assert tool_msg.tool_call_id == "tc_abc"
        assert tool_msg.content == "search result"

    @pytest.mark.asyncio
    async def test_uncompressed_passthrough_keeps_fields(self):
        """When compression is a no-op (was_compacted=False), the original
        messages must be returned unchanged — including tool_calls."""
        session = [
            ChatMessage(
                role="assistant",
                content="c",
                tool_calls=[
                    {
                        "id": "t1",
                        "type": "function",
                        "function": {"name": "f", "arguments": "{}"},
                    }
                ],
            ),
            ChatMessage(role="tool", tool_call_id="t1", content="ok"),
        ]
        noop = CompressResult(
            messages=[],  # ignored when was_compacted=False
            token_count=10,
            was_compacted=False,
        )

        with patch(
            "backend.copilot.baseline.service.compress_context",
            new=AsyncMock(return_value=noop),
        ):
            out = await _compress_session_messages(
                session, model="openrouter/anthropic/claude-opus-4"
            )

        assert out is session  # the very same list object is returned
        assert out[0].tool_calls is not None
        assert out[0].tool_calls[0]["id"] == "t1"
        assert out[1].tool_call_id == "t1"
||||
|
||||
|
||||
# ---- _ThinkingStripper tests ---- #
|
||||
|
||||
|
||||
def test_thinking_stripper_basic_thinking_tag() -> None:
    """<thinking>...</thinking> blocks are fully stripped."""
    s = _ThinkingStripper()
    assert s.process("<thinking>internal reasoning here</thinking>Hello!") == "Hello!"
|
||||
|
||||
|
||||
def test_thinking_stripper_internal_reasoning_tag() -> None:
    """<internal_reasoning>...</internal_reasoning> blocks (Gemini) are stripped."""
    s = _ThinkingStripper()
    assert (
        s.process("<internal_reasoning>step by step</internal_reasoning>Answer")
        == "Answer"
    )
|
||||
|
||||
|
||||
def test_thinking_stripper_split_across_chunks() -> None:
    """Tags split across multiple chunks are handled correctly.

    The opening tag is deliberately broken mid-token ("<thin" / "king>") to
    exercise the stripper's cross-chunk buffering.
    """
    s = _ThinkingStripper()
    out = s.process("Hello <thin")
    out += s.process("king>secret</thinking> world")
    assert out == "Hello world"
|
||||
|
||||
|
||||
def test_thinking_stripper_plain_text_preserved() -> None:
    """Plain text with the word 'thinking' is not stripped."""
    s = _ThinkingStripper()
    assert (
        s.process("I am thinking about this problem")
        == "I am thinking about this problem"
    )
|
||||
|
||||
|
||||
def test_thinking_stripper_multiple_blocks() -> None:
    """Multiple reasoning blocks in one stream are all stripped."""
    s = _ThinkingStripper()
    result = s.process(
        "A<thinking>x</thinking>B<internal_reasoning>y</internal_reasoning>C"
    )
    assert result == "ABC"
|
||||
|
||||
|
||||
def test_thinking_stripper_flush_discards_unclosed() -> None:
    """Unclosed reasoning block is discarded on flush."""
    s = _ThinkingStripper()
    s.process("Start<thinking>never closed")
    flushed = s.flush()
    # Whatever flush() returns, the buffered reasoning text must not leak.
    assert "never closed" not in flushed
|
||||
|
||||
|
||||
def test_thinking_stripper_empty_block() -> None:
    """Empty reasoning blocks are handled gracefully."""
    s = _ThinkingStripper()
    assert s.process("Before<thinking></thinking>After") == "BeforeAfter"
|
||||
|
||||
|
||||
# ---- _filter_tools_by_permissions tests ---- #
|
||||
|
||||
|
||||
def _make_tool(name: str) -> ChatCompletionToolParam:
    """Build a minimal OpenAI ChatCompletionToolParam.

    Only `name` and an empty parameter schema are populated — enough for the
    permission-filter tests, which key purely on the function name.
    """
    return ChatCompletionToolParam(
        type="function",
        function={"name": name, "parameters": {}},
    )
|
||||
|
||||
|
||||
class TestFilterToolsByPermissions:
    """Tests for _filter_tools_by_permissions."""

    @patch(
        "backend.copilot.permissions.all_known_tool_names",
        return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
    )
    def test_empty_permissions_returns_all(self, _mock_names):
        """Empty permissions (no filtering) returns every tool unchanged."""
        from backend.copilot.baseline.service import _filter_tools_by_permissions
        from backend.copilot.permissions import CopilotPermissions

        tools = [_make_tool("run_block"), _make_tool("web_fetch")]
        perms = CopilotPermissions()
        result = _filter_tools_by_permissions(tools, perms)
        assert result == tools

    @patch(
        "backend.copilot.permissions.all_known_tool_names",
        return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
    )
    def test_allowlist_keeps_only_matching(self, _mock_names):
        """Explicit allowlist (tools_exclude=False) keeps only listed tools."""
        from backend.copilot.baseline.service import _filter_tools_by_permissions
        from backend.copilot.permissions import CopilotPermissions

        tools = [
            _make_tool("run_block"),
            _make_tool("web_fetch"),
            _make_tool("bash_exec"),
        ]
        perms = CopilotPermissions(tools=["web_fetch"], tools_exclude=False)
        result = _filter_tools_by_permissions(tools, perms)
        assert len(result) == 1
        assert result[0]["function"]["name"] == "web_fetch"

    @patch(
        "backend.copilot.permissions.all_known_tool_names",
        return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
    )
    def test_blacklist_excludes_listed(self, _mock_names):
        """Blacklist (tools_exclude=True) removes only the listed tools."""
        from backend.copilot.baseline.service import _filter_tools_by_permissions
        from backend.copilot.permissions import CopilotPermissions

        tools = [
            _make_tool("run_block"),
            _make_tool("web_fetch"),
            _make_tool("bash_exec"),
        ]
        perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
        result = _filter_tools_by_permissions(tools, perms)
        names = [t["function"]["name"] for t in result]
        assert "bash_exec" not in names
        assert "run_block" in names
        assert "web_fetch" in names
        assert len(result) == 2

    @patch(
        "backend.copilot.permissions.all_known_tool_names",
        return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
    )
    def test_unknown_tool_name_filtered_out(self, _mock_names):
        """A tool whose name is not in all_known_tool_names is dropped."""
        from backend.copilot.baseline.service import _filter_tools_by_permissions
        from backend.copilot.permissions import CopilotPermissions

        tools = [_make_tool("run_block"), _make_tool("unknown_tool")]
        perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
        result = _filter_tools_by_permissions(tools, perms)
        names = [t["function"]["name"] for t in result]
        assert "unknown_tool" not in names
        assert names == ["run_block"]
|
||||
|
||||
|
||||
# ---- _prepare_baseline_attachments tests ---- #
|
||||
|
||||
|
||||
class TestPrepareBaselineAttachments:
    """Tests for _prepare_baseline_attachments."""

    @pytest.mark.asyncio
    async def test_empty_file_ids(self):
        """Empty file_ids returns empty hint and blocks."""
        from backend.copilot.baseline.service import _prepare_baseline_attachments

        hint, blocks = await _prepare_baseline_attachments([], "user1", "sess1", "/tmp")
        assert hint == ""
        assert blocks == []

    @pytest.mark.asyncio
    async def test_empty_user_id(self):
        """Empty user_id returns empty hint and blocks."""
        from backend.copilot.baseline.service import _prepare_baseline_attachments

        hint, blocks = await _prepare_baseline_attachments(
            ["file1"], "", "sess1", "/tmp"
        )
        assert hint == ""
        assert blocks == []

    @pytest.mark.asyncio
    async def test_image_file_returns_vision_blocks(self):
        """A PNG image within size limits is returned as a base64 vision block."""
        from backend.copilot.baseline.service import _prepare_baseline_attachments

        # NOTE: `.name` must be assigned after construction — passing name=
        # to the Mock constructor would configure the mock itself.
        fake_info = AsyncMock()
        fake_info.name = "photo.png"
        fake_info.mime_type = "image/png"
        fake_info.size_bytes = 1024

        fake_manager = AsyncMock()
        fake_manager.get_file_info = AsyncMock(return_value=fake_info)
        fake_manager.read_file_by_id = AsyncMock(return_value=b"\x89PNG_FAKE_DATA")

        with patch(
            "backend.copilot.baseline.service.get_workspace_manager",
            new=AsyncMock(return_value=fake_manager),
        ):
            hint, blocks = await _prepare_baseline_attachments(
                ["fid1"], "user1", "sess1", "/tmp/workdir"
            )

        assert len(blocks) == 1
        assert blocks[0]["type"] == "image"
        assert blocks[0]["source"]["media_type"] == "image/png"
        assert blocks[0]["source"]["type"] == "base64"
        assert "photo.png" in hint
        assert "embedded as image" in hint

    @pytest.mark.asyncio
    async def test_non_image_file_saved_to_working_dir(self, tmp_path):
        """A non-image file is written to working_dir."""
        from backend.copilot.baseline.service import _prepare_baseline_attachments

        fake_info = AsyncMock()
        fake_info.name = "data.csv"
        fake_info.mime_type = "text/csv"
        fake_info.size_bytes = 42

        fake_manager = AsyncMock()
        fake_manager.get_file_info = AsyncMock(return_value=fake_info)
        fake_manager.read_file_by_id = AsyncMock(return_value=b"col1,col2\na,b")

        with patch(
            "backend.copilot.baseline.service.get_workspace_manager",
            new=AsyncMock(return_value=fake_manager),
        ):
            hint, blocks = await _prepare_baseline_attachments(
                ["fid1"], "user1", "sess1", str(tmp_path)
            )

        assert blocks == []
        assert "data.csv" in hint
        assert "saved to" in hint
        # The file content must actually land on disk under working_dir.
        saved = tmp_path / "data.csv"
        assert saved.exists()
        assert saved.read_bytes() == b"col1,col2\na,b"

    @pytest.mark.asyncio
    async def test_file_not_found_skipped(self):
        """When get_file_info returns None the file is silently skipped."""
        from backend.copilot.baseline.service import _prepare_baseline_attachments

        fake_manager = AsyncMock()
        fake_manager.get_file_info = AsyncMock(return_value=None)

        with patch(
            "backend.copilot.baseline.service.get_workspace_manager",
            new=AsyncMock(return_value=fake_manager),
        ):
            hint, blocks = await _prepare_baseline_attachments(
                ["missing_id"], "user1", "sess1", "/tmp"
            )

        assert hint == ""
        assert blocks == []

    @pytest.mark.asyncio
    async def test_workspace_manager_error(self):
        """When get_workspace_manager raises, returns empty results."""
        from backend.copilot.baseline.service import _prepare_baseline_attachments

        with patch(
            "backend.copilot.baseline.service.get_workspace_manager",
            new=AsyncMock(side_effect=RuntimeError("connection failed")),
        ):
            hint, blocks = await _prepare_baseline_attachments(
                ["fid1"], "user1", "sess1", "/tmp"
            )

        assert hint == ""
        assert blocks == []
|
||||
|
||||
|
||||
class TestBaselineCostExtraction:
    """Tests for x-total-cost header extraction in _baseline_llm_caller."""

    @pytest.mark.asyncio
    async def test_cost_usd_extracted_from_response_header(self):
        """state.cost_usd is set from x-total-cost header when present."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        # Build a mock raw httpx response with the cost header
        mock_raw_response = MagicMock()
        mock_raw_response.headers = {"x-total-cost": "0.0123"}

        # Build a mock async streaming response that yields no chunks but has
        # a _response attribute pointing to the mock httpx response
        mock_stream_response = MagicMock()
        mock_stream_response._response = mock_raw_response

        async def empty_aiter():
            return
            yield  # make it an async generator

        mock_stream_response.__aiter__ = lambda self: empty_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(
            return_value=mock_stream_response
        )

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd == pytest.approx(0.0123)

    @pytest.mark.asyncio
    async def test_cost_usd_accumulates_across_calls(self):
        """cost_usd accumulates when _baseline_llm_caller is called multiple times."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        def make_stream_mock(cost: str) -> MagicMock:
            """Build an empty streaming response carrying the given cost header."""
            mock_raw = MagicMock()
            mock_raw.headers = {"x-total-cost": cost}
            mock_stream = MagicMock()
            mock_stream._response = mock_raw

            async def empty_aiter():
                return
                yield

            mock_stream.__aiter__ = lambda self: empty_aiter()
            return mock_stream

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(
            side_effect=[make_stream_mock("0.01"), make_stream_mock("0.02")]
        )

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "first"}],
                tools=[],
                state=state,
            )
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "second"}],
                tools=[],
                state=state,
            )

        # 0.01 + 0.02 from the two headers above.
        assert state.cost_usd == pytest.approx(0.03)

    @pytest.mark.asyncio
    async def test_no_cost_when_header_absent(self):
        """state.cost_usd remains None when response has no x-total-cost header."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        mock_raw = MagicMock()
        mock_raw.headers = {}
        mock_stream = MagicMock()
        mock_stream._response = mock_raw

        async def empty_aiter():
            return
            yield

        mock_stream.__aiter__ = lambda self: empty_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd is None

    @pytest.mark.asyncio
    async def test_cost_extracted_even_when_stream_raises(self):
        """cost_usd is captured in the finally block even when streaming fails."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        mock_raw = MagicMock()
        mock_raw.headers = {"x-total-cost": "0.005"}
        mock_stream = MagicMock()
        mock_stream._response = mock_raw

        async def failing_aiter():
            raise RuntimeError("stream error")
            yield  # make it an async generator

        mock_stream.__aiter__ = lambda self: failing_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)

        with (
            patch(
                "backend.copilot.baseline.service._get_openai_client",
                return_value=mock_client,
            ),
            pytest.raises(RuntimeError, match="stream error"),
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        # Even though the stream blew up, the header must have been read.
        assert state.cost_usd == pytest.approx(0.005)
|
||||
@@ -0,0 +1,667 @@
|
||||
"""Integration tests for baseline transcript flow.
|
||||
|
||||
Exercises the real helpers in ``baseline/service.py`` that download,
|
||||
validate, load, append to, backfill, and upload the transcript.
|
||||
Storage is mocked via ``download_transcript`` / ``upload_transcript``
|
||||
patches; no network access is required.
|
||||
"""
|
||||
|
||||
import json as stdlib_json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_load_prior_transcript,
|
||||
_record_turn_to_transcript,
|
||||
_resolve_baseline_model,
|
||||
_upload_final_transcript,
|
||||
is_transcript_stale,
|
||||
should_upload_transcript,
|
||||
)
|
||||
from backend.copilot.service import config
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
STOP_REASON_TOOL_USE,
|
||||
TranscriptDownload,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
|
||||
|
||||
|
||||
def _make_transcript_content(*roles: str) -> str:
    """Build a minimal valid JSONL transcript from role names.

    Each role becomes one JSONL entry; entries are chained via
    ``parentUuid`` so the transcript forms a single linear thread.
    Assistant entries get the extra message metadata the validator expects.
    """
    lines: list[str] = []
    parent = ""
    for i, role in enumerate(roles):
        uid = f"uuid-{i}"
        entry: dict = {
            "type": role,
            "uuid": uid,
            "parentUuid": parent,
            "message": {
                "role": role,
                "content": [{"type": "text", "text": f"{role} message {i}"}],
            },
        }
        if role == "assistant":
            entry["message"]["id"] = f"msg_{i}"
            entry["message"]["model"] = "test-model"
            entry["message"]["type"] = "message"
            entry["message"]["stop_reason"] = STOP_REASON_END_TURN
        lines.append(stdlib_json.dumps(entry))
        parent = uid
    # Trailing newline matches the on-disk JSONL convention.
    return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
class TestResolveBaselineModel:
    """Model selection honours the per-request mode."""

    def test_fast_mode_selects_fast_model(self):
        assert _resolve_baseline_model("fast") == config.fast_model

    def test_extended_thinking_selects_default_model(self):
        assert _resolve_baseline_model("extended_thinking") == config.model

    def test_none_mode_selects_default_model(self):
        """Critical: baseline users without a mode MUST keep the default (opus)."""
        assert _resolve_baseline_model(None) == config.model

    def test_default_and_fast_models_differ(self):
        """Sanity: the two tiers are actually distinct in production config."""
        assert config.model != config.fast_model
|
||||
|
||||
|
||||
class TestLoadPriorTranscript:
    """``_load_prior_transcript`` wraps the download + validate + load flow."""

    @pytest.mark.asyncio
    async def test_loads_fresh_transcript(self):
        builder = TranscriptBuilder()
        content = _make_transcript_content("user", "assistant")
        download = TranscriptDownload(content=content, message_count=2)

        with patch(
            "backend.copilot.baseline.service.download_transcript",
            new=AsyncMock(return_value=download),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=3,
                transcript_builder=builder,
            )

        assert covers is True
        assert builder.entry_count == 2
        assert builder.last_entry_type == "assistant"

    @pytest.mark.asyncio
    async def test_rejects_stale_transcript(self):
        """msg_count strictly less than session-1 is treated as stale."""
        builder = TranscriptBuilder()
        content = _make_transcript_content("user", "assistant")
        # session has 6 messages, transcript only covers 2 → stale.
        download = TranscriptDownload(content=content, message_count=2)

        with patch(
            "backend.copilot.baseline.service.download_transcript",
            new=AsyncMock(return_value=download),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=6,
                transcript_builder=builder,
            )

        assert covers is False
        assert builder.is_empty

    @pytest.mark.asyncio
    async def test_missing_transcript_returns_false(self):
        builder = TranscriptBuilder()
        with patch(
            "backend.copilot.baseline.service.download_transcript",
            new=AsyncMock(return_value=None),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=2,
                transcript_builder=builder,
            )

        assert covers is False
        assert builder.is_empty

    @pytest.mark.asyncio
    async def test_invalid_transcript_returns_false(self):
        builder = TranscriptBuilder()
        # A lone "progress" entry is not a valid transcript payload.
        download = TranscriptDownload(
            content='{"type":"progress","uuid":"a"}\n',
            message_count=1,
        )
        with patch(
            "backend.copilot.baseline.service.download_transcript",
            new=AsyncMock(return_value=download),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=2,
                transcript_builder=builder,
            )

        assert covers is False
        assert builder.is_empty

    @pytest.mark.asyncio
    async def test_download_exception_returns_false(self):
        builder = TranscriptBuilder()
        with patch(
            "backend.copilot.baseline.service.download_transcript",
            new=AsyncMock(side_effect=RuntimeError("boom")),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=2,
                transcript_builder=builder,
            )

        assert covers is False
        assert builder.is_empty

    @pytest.mark.asyncio
    async def test_zero_message_count_not_stale(self):
        """When msg_count is 0 (unknown), staleness check is skipped."""
        builder = TranscriptBuilder()
        download = TranscriptDownload(
            content=_make_transcript_content("user", "assistant"),
            message_count=0,
        )
        with patch(
            "backend.copilot.baseline.service.download_transcript",
            new=AsyncMock(return_value=download),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=20,
                transcript_builder=builder,
            )

        assert covers is True
        assert builder.entry_count == 2
|
||||
|
||||
|
||||
class TestUploadFinalTranscript:
    """``_upload_final_transcript`` serialises and calls storage."""

    @pytest.mark.asyncio
    async def test_uploads_valid_transcript(self):
        builder = TranscriptBuilder()
        builder.append_user(content="hi")
        builder.append_assistant(
            content_blocks=[{"type": "text", "text": "hello"}],
            model="test-model",
            stop_reason=STOP_REASON_END_TURN,
        )

        upload_mock = AsyncMock(return_value=None)
        with patch(
            "backend.copilot.baseline.service.upload_transcript",
            new=upload_mock,
        ):
            await _upload_final_transcript(
                user_id="user-1",
                session_id="session-1",
                transcript_builder=builder,
                session_msg_count=2,
            )

        upload_mock.assert_awaited_once()
        assert upload_mock.await_args is not None
        call_kwargs = upload_mock.await_args.kwargs
        assert call_kwargs["user_id"] == "user-1"
        assert call_kwargs["session_id"] == "session-1"
        assert call_kwargs["message_count"] == 2
        assert "hello" in call_kwargs["content"]

    @pytest.mark.asyncio
    async def test_skips_upload_when_builder_empty(self):
        builder = TranscriptBuilder()
        upload_mock = AsyncMock(return_value=None)
        with patch(
            "backend.copilot.baseline.service.upload_transcript",
            new=upload_mock,
        ):
            await _upload_final_transcript(
                user_id="user-1",
                session_id="session-1",
                transcript_builder=builder,
                session_msg_count=0,
            )

        upload_mock.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_swallows_upload_exceptions(self):
        """Upload failures should not propagate (flow continues for the user)."""
        builder = TranscriptBuilder()
        builder.append_user(content="hi")
        builder.append_assistant(
            content_blocks=[{"type": "text", "text": "hello"}],
            model="test-model",
            stop_reason=STOP_REASON_END_TURN,
        )

        with patch(
            "backend.copilot.baseline.service.upload_transcript",
            new=AsyncMock(side_effect=RuntimeError("storage unavailable")),
        ):
            # Should not raise.
            await _upload_final_transcript(
                user_id="user-1",
                session_id="session-1",
                transcript_builder=builder,
                session_msg_count=2,
            )
|
||||
|
||||
|
||||
class TestRecordTurnToTranscript:
    """``_record_turn_to_transcript`` translates LLMLoopResponse → transcript."""

    def test_records_final_assistant_text(self):
        builder = TranscriptBuilder()
        builder.append_user(content="hi")

        response = LLMLoopResponse(
            response_text="hello there",
            tool_calls=[],
            raw_response=None,
        )
        _record_turn_to_transcript(
            response,
            tool_results=None,
            transcript_builder=builder,
            model="test-model",
        )

        assert builder.entry_count == 2
        assert builder.last_entry_type == "assistant"
        jsonl = builder.to_jsonl()
        assert "hello there" in jsonl
        assert STOP_REASON_END_TURN in jsonl

    def test_records_tool_use_then_tool_result(self):
        """Anthropic ordering: assistant(tool_use) → user(tool_result)."""
        builder = TranscriptBuilder()
        builder.append_user(content="use a tool")

        response = LLMLoopResponse(
            response_text=None,
            tool_calls=[
                LLMToolCall(id="call-1", name="echo", arguments='{"text":"hi"}')
            ],
            raw_response=None,
        )
        tool_results = [
            ToolCallResult(tool_call_id="call-1", tool_name="echo", content="hi")
        ]
        _record_turn_to_transcript(
            response,
            tool_results,
            transcript_builder=builder,
            model="test-model",
        )

        # user, assistant(tool_use), user(tool_result) = 3 entries
        assert builder.entry_count == 3
        jsonl = builder.to_jsonl()
        assert STOP_REASON_TOOL_USE in jsonl
        assert "tool_use" in jsonl
        assert "tool_result" in jsonl
        assert "call-1" in jsonl

    def test_records_nothing_on_empty_response(self):
        builder = TranscriptBuilder()
        builder.append_user(content="hi")

        response = LLMLoopResponse(
            response_text=None,
            tool_calls=[],
            raw_response=None,
        )
        _record_turn_to_transcript(
            response,
            tool_results=None,
            transcript_builder=builder,
            model="test-model",
        )

        # Nothing to record → only the original user entry remains.
        assert builder.entry_count == 1

    def test_malformed_tool_args_dont_crash(self):
        """Bad JSON in tool arguments falls back to {} without raising."""
        builder = TranscriptBuilder()
        builder.append_user(content="hi")

        response = LLMLoopResponse(
            response_text=None,
            tool_calls=[LLMToolCall(id="call-1", name="echo", arguments="{not-json")],
            raw_response=None,
        )
        tool_results = [
            ToolCallResult(tool_call_id="call-1", tool_name="echo", content="ok")
        ]
        _record_turn_to_transcript(
            response,
            tool_results,
            transcript_builder=builder,
            model="test-model",
        )

        assert builder.entry_count == 3
        jsonl = builder.to_jsonl()
        assert '"input":{}' in jsonl
|
||||
|
||||
|
||||
class TestRoundTrip:
    """End-to-end: load prior → append new turn → upload."""

    @pytest.mark.asyncio
    async def test_full_round_trip(self):
        prior = _make_transcript_content("user", "assistant")
        download = TranscriptDownload(content=prior, message_count=2)

        builder = TranscriptBuilder()
        with patch(
            "backend.copilot.baseline.service.download_transcript",
            new=AsyncMock(return_value=download),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=3,
                transcript_builder=builder,
            )
        assert covers is True
        assert builder.entry_count == 2

        # New user turn.
        builder.append_user(content="new question")
        assert builder.entry_count == 3

        # New assistant turn.
        response = LLMLoopResponse(
            response_text="new answer",
            tool_calls=[],
            raw_response=None,
        )
        _record_turn_to_transcript(
            response,
            tool_results=None,
            transcript_builder=builder,
            model="test-model",
        )
        assert builder.entry_count == 4

        # Upload.
        upload_mock = AsyncMock(return_value=None)
        with patch(
            "backend.copilot.baseline.service.upload_transcript",
            new=upload_mock,
        ):
            await _upload_final_transcript(
                user_id="user-1",
                session_id="session-1",
                transcript_builder=builder,
                session_msg_count=4,
            )

        upload_mock.assert_awaited_once()
        assert upload_mock.await_args is not None
        uploaded = upload_mock.await_args.kwargs["content"]
        assert "new question" in uploaded
        assert "new answer" in uploaded
        # Original content preserved in the round trip.
        assert "user message 0" in uploaded
        assert "assistant message 1" in uploaded

    @pytest.mark.asyncio
    async def test_backfill_append_guard(self):
        """Backfill only runs when the last entry is not already assistant."""
        builder = TranscriptBuilder()
        builder.append_user(content="hi")

        # Simulate the backfill guard from stream_chat_completion_baseline.
        assistant_text = "partial text before error"
        if builder.last_entry_type != "assistant":
            builder.append_assistant(
                content_blocks=[{"type": "text", "text": assistant_text}],
                model="test-model",
                stop_reason=STOP_REASON_END_TURN,
            )

        assert builder.last_entry_type == "assistant"
        assert "partial text before error" in builder.to_jsonl()

        # Second invocation: the guard must prevent double-append.
        initial_count = builder.entry_count
        if builder.last_entry_type != "assistant":
            builder.append_assistant(
                content_blocks=[{"type": "text", "text": "duplicate"}],
                model="test-model",
                stop_reason=STOP_REASON_END_TURN,
            )
        assert builder.entry_count == initial_count
|
||||
|
||||
|
||||
class TestIsTranscriptStale:
    """``is_transcript_stale`` gates prior-transcript loading."""

    def test_none_download_is_not_stale(self):
        assert is_transcript_stale(None, session_msg_count=5) is False

    def test_zero_message_count_is_not_stale(self):
        """Legacy transcripts without msg_count tracking must remain usable."""
        dl = TranscriptDownload(content="", message_count=0)
        assert is_transcript_stale(dl, session_msg_count=20) is False

    def test_stale_when_covers_less_than_prefix(self):
        dl = TranscriptDownload(content="", message_count=2)
        # session has 6 messages; transcript must cover at least 5 (6-1).
        assert is_transcript_stale(dl, session_msg_count=6) is True

    def test_fresh_when_covers_full_prefix(self):
        dl = TranscriptDownload(content="", message_count=5)
        assert is_transcript_stale(dl, session_msg_count=6) is False

    def test_fresh_when_exceeds_prefix(self):
        """Race: transcript ahead of session count is still acceptable."""
        dl = TranscriptDownload(content="", message_count=10)
        assert is_transcript_stale(dl, session_msg_count=6) is False

    def test_boundary_equal_to_prefix_minus_one(self):
        dl = TranscriptDownload(content="", message_count=5)
        assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
|
||||
class TestShouldUploadTranscript:
|
||||
"""``should_upload_transcript`` gates the final upload."""
|
||||
|
||||
def test_upload_allowed_for_user_with_coverage(self):
|
||||
assert should_upload_transcript("user-1", True) is True
|
||||
|
||||
def test_upload_skipped_when_no_user(self):
|
||||
assert should_upload_transcript(None, True) is False
|
||||
|
||||
def test_upload_skipped_when_empty_user(self):
|
||||
assert should_upload_transcript("", True) is False
|
||||
|
||||
def test_upload_skipped_without_coverage(self):
|
||||
"""Partial transcript must never clobber a more complete stored one."""
|
||||
assert should_upload_transcript("user-1", False) is False
|
||||
|
||||
def test_upload_skipped_when_no_user_and_no_coverage(self):
|
||||
assert should_upload_transcript(None, False) is False
|
||||
|
||||
|
||||
class TestTranscriptLifecycle:
|
||||
"""End-to-end: download → validate → build → upload.
|
||||
|
||||
Simulates the full transcript lifecycle inside
|
||||
``stream_chat_completion_baseline`` by mocking the storage layer and
|
||||
driving each step through the real helpers.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_lifecycle_happy_path(self):
|
||||
"""Fresh download, append a turn, upload covers the session."""
|
||||
builder = TranscriptBuilder()
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=prior, message_count=2)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
# --- 1. Download & load prior transcript ---
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
|
||||
# --- 2. Append a new user turn + a new assistant response ---
|
||||
builder.append_user(content="follow-up question")
|
||||
_record_turn_to_transcript(
|
||||
LLMLoopResponse(
|
||||
response_text="follow-up answer",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
),
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# --- 3. Gate + upload ---
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
)
|
||||
is True
|
||||
)
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=4,
|
||||
)
|
||||
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
uploaded = upload_mock.await_args.kwargs["content"]
|
||||
assert "follow-up question" in uploaded
|
||||
assert "follow-up answer" in uploaded
|
||||
# Original prior-turn content preserved.
|
||||
assert "user message 0" in uploaded
|
||||
assert "assistant message 1" in uploaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_stale_download_suppresses_upload(self):
|
||||
"""Stale download → covers=False → upload must be skipped."""
|
||||
builder = TranscriptBuilder()
|
||||
# session has 10 msgs but stored transcript only covers 2 → stale.
|
||||
stale = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
message_count=2,
|
||||
)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=stale),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=10,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
# The caller's gate mirrors the production path.
|
||||
assert (
|
||||
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
|
||||
is False
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_anonymous_user_skips_upload(self):
|
||||
"""Anonymous (user_id=None) → upload gate must return False."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "hello"}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
assert (
|
||||
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
|
||||
is False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_missing_download_still_uploads_new_content(self):
|
||||
"""No prior transcript → covers defaults to True in the service,
|
||||
new turn should upload cleanly."""
|
||||
builder = TranscriptBuilder()
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=1,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
# No download: covers is False, so the production path would
|
||||
# skip upload. This protects against overwriting a future
|
||||
# more-complete transcript with a single-turn snapshot.
|
||||
assert covers is False
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
)
|
||||
is False
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
@@ -8,18 +8,35 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
|
||||
# Per-request routing mode for a single chat turn.
|
||||
# - 'fast': route to the baseline OpenAI-compatible path with the cheaper model.
|
||||
# - 'extended_thinking': route to the Claude Agent SDK path with the default
|
||||
# (opus) model.
|
||||
# ``None`` means "no override"; the server falls back to the Claude Code
|
||||
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
|
||||
CopilotMode = Literal["fast", "extended_thinking"]
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
|
||||
# OpenAI API Configuration
|
||||
model: str = Field(
|
||||
default="anthropic/claude-opus-4.6", description="Default model to use"
|
||||
default="anthropic/claude-opus-4.6",
|
||||
description="Default model for extended thinking mode",
|
||||
)
|
||||
fast_model: str = Field(
|
||||
default="anthropic/claude-sonnet-4",
|
||||
description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
|
||||
)
|
||||
title_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
description="Model to use for generating session titles (should be fast/cheap)",
|
||||
)
|
||||
simulation_model: str = Field(
|
||||
default="google/gemini-2.5-flash",
|
||||
description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
default=OPENROUTER_BASE_URL,
|
||||
@@ -77,11 +94,11 @@ class ChatConfig(BaseSettings):
|
||||
# allows ~70-100 turns/day.
|
||||
# Checked at the HTTP layer (routes.py) before each turn.
|
||||
#
|
||||
# TODO: These are deploy-time constants applied identically to every user.
|
||||
# If per-user or per-plan limits are needed (e.g., free tier vs paid), these
|
||||
# must move to the database (e.g., a UserPlan table) and get_usage_status /
|
||||
# check_rate_limit would look up each user's specific limits instead of
|
||||
# reading config.daily_token_limit / config.weekly_token_limit.
|
||||
# These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
|
||||
# ENTERPRISE) multiply these by their tier multiplier (see
|
||||
# rate_limit.TIER_MULTIPLIERS). User tier is stored in the
|
||||
# User.subscriptionTier DB column and resolved inside
|
||||
# get_global_rate_limits().
|
||||
daily_token_limit: int = Field(
|
||||
default=2_500_000,
|
||||
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
|
||||
|
||||
@@ -149,7 +149,8 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``
|
||||
or ``tool-outputs/...``.
|
||||
The SDK nests tool-results under a conversation UUID directory;
|
||||
the UUID segment is validated with ``_UUID_RE``.
|
||||
"""
|
||||
@@ -174,17 +175,20 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
# Defence-in-depth: ensure project_dir didn't escape the base.
|
||||
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
|
||||
return False
|
||||
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
|
||||
# Only allow: <encoded-cwd>/<uuid>/<tool-dir>/<file>
|
||||
# The SDK always creates a conversation UUID directory between
|
||||
# the project dir and tool-results/.
|
||||
# the project dir and the tool directory.
|
||||
# Accept both "tool-results" (SDK's persisted outputs) and
|
||||
# "tool-outputs" (the model sometimes confuses workspace paths
|
||||
# with filesystem paths and generates this variant).
|
||||
if resolved.startswith(project_dir + os.sep):
|
||||
relative = resolved[len(project_dir) + 1 :]
|
||||
parts = relative.split(os.sep)
|
||||
# Require exactly: [<uuid>, "tool-results", <file>, ...]
|
||||
# Require exactly: [<uuid>, "tool-results"|"tool-outputs", <file>, ...]
|
||||
if (
|
||||
len(parts) >= 3
|
||||
and _UUID_RE.match(parts[0])
|
||||
and parts[1] == "tool-results"
|
||||
and parts[1] in ("tool-results", "tool-outputs")
|
||||
):
|
||||
return True
|
||||
|
||||
|
||||
@@ -134,6 +134,21 @@ def test_is_allowed_local_path_tool_results_with_uuid():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_outputs_with_uuid():
|
||||
"""Files under <encoded-cwd>/<uuid>/tool-outputs/ are also allowed."""
|
||||
encoded = "test-encoded-dir"
|
||||
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-outputs", "output.json"
|
||||
)
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
|
||||
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
@@ -159,7 +174,7 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
|
||||
|
||||
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
|
||||
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
|
||||
"""A valid UUID dir but non-'tool-results'/'tool-outputs' second segment is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
uuid_str = "12345678-1234-5678-9abc-def012345678"
|
||||
path = os.path.join(
|
||||
|
||||
@@ -14,6 +14,7 @@ from prisma.types import (
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
@@ -23,12 +24,22 @@ from .model import (
|
||||
ChatSession,
|
||||
ChatSessionInfo,
|
||||
ChatSessionMetadata,
|
||||
invalidate_session_cache,
|
||||
cache_chat_session,
|
||||
)
|
||||
from .model import get_chat_session as get_chat_session_cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PaginatedMessages(BaseModel):
|
||||
"""Result of a paginated message query."""
|
||||
|
||||
messages: list[ChatMessage]
|
||||
has_more: bool
|
||||
oldest_sequence: int | None
|
||||
session: ChatSessionInfo
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session by ID from the database."""
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
@@ -38,6 +49,116 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def get_chat_session_metadata(session_id: str) -> ChatSessionInfo | None:
|
||||
"""Get chat session metadata (without messages) for ownership validation."""
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
)
|
||||
return ChatSessionInfo.from_db(session) if session else None
|
||||
|
||||
|
||||
async def get_chat_messages_paginated(
|
||||
session_id: str,
|
||||
limit: int = 50,
|
||||
before_sequence: int | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> PaginatedMessages | None:
|
||||
"""Get paginated messages for a session, newest first.
|
||||
|
||||
Verifies session existence (and ownership when ``user_id`` is provided)
|
||||
in parallel with the message query. Returns ``None`` when the session
|
||||
is not found or does not belong to the user.
|
||||
|
||||
Args:
|
||||
session_id: The chat session ID.
|
||||
limit: Max messages to return.
|
||||
before_sequence: Cursor — return messages with sequence < this value.
|
||||
user_id: If provided, filters via ``Session.userId`` so only the
|
||||
session owner's messages are returned (acts as an ownership guard).
|
||||
"""
|
||||
# Build session-existence / ownership check
|
||||
session_where: ChatSessionWhereInput = {"id": session_id}
|
||||
if user_id is not None:
|
||||
session_where["userId"] = user_id
|
||||
|
||||
# Build message include — fetch paginated messages in the same query
|
||||
msg_include: dict[str, Any] = {
|
||||
"order_by": {"sequence": "desc"},
|
||||
"take": limit + 1,
|
||||
}
|
||||
if before_sequence is not None:
|
||||
msg_include["where"] = {"sequence": {"lt": before_sequence}}
|
||||
|
||||
# Single query: session existence/ownership + paginated messages
|
||||
session = await PrismaChatSession.prisma().find_first(
|
||||
where=session_where,
|
||||
include={"Messages": msg_include},
|
||||
)
|
||||
|
||||
if session is None:
|
||||
return None
|
||||
|
||||
session_info = ChatSessionInfo.from_db(session)
|
||||
results = list(session.Messages) if session.Messages else []
|
||||
|
||||
has_more = len(results) > limit
|
||||
results = results[:limit]
|
||||
|
||||
# Reverse to ascending order
|
||||
results.reverse()
|
||||
|
||||
# Tool-call boundary fix: if the oldest message is a tool message,
|
||||
# expand backward to include the preceding assistant message that
|
||||
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
|
||||
# can pair them correctly.
|
||||
_BOUNDARY_SCAN_LIMIT = 10
|
||||
if results and results[0].role == "tool":
|
||||
boundary_where: dict[str, Any] = {
|
||||
"sessionId": session_id,
|
||||
"sequence": {"lt": results[0].sequence},
|
||||
}
|
||||
if user_id is not None:
|
||||
boundary_where["Session"] = {"is": {"userId": user_id}}
|
||||
extra = await PrismaChatMessage.prisma().find_many(
|
||||
where=boundary_where,
|
||||
order={"sequence": "desc"},
|
||||
take=_BOUNDARY_SCAN_LIMIT,
|
||||
)
|
||||
# Find the first non-tool message (should be the assistant)
|
||||
boundary_msgs = []
|
||||
found_owner = False
|
||||
for msg in extra:
|
||||
boundary_msgs.append(msg)
|
||||
if msg.role != "tool":
|
||||
found_owner = True
|
||||
break
|
||||
boundary_msgs.reverse()
|
||||
if not found_owner:
|
||||
logger.warning(
|
||||
"Boundary expansion did not find owning assistant message "
|
||||
"for session=%s before sequence=%s (%d msgs scanned)",
|
||||
session_id,
|
||||
results[0].sequence,
|
||||
len(extra),
|
||||
)
|
||||
if boundary_msgs:
|
||||
results = boundary_msgs + results
|
||||
# Only mark has_more if the expanded boundary isn't the
|
||||
# very start of the conversation (sequence 0).
|
||||
if boundary_msgs[0].sequence > 0:
|
||||
has_more = True
|
||||
|
||||
messages = [ChatMessage.from_db(m) for m in results]
|
||||
oldest_sequence = messages[0].sequence if messages else None
|
||||
|
||||
return PaginatedMessages(
|
||||
messages=messages,
|
||||
has_more=has_more,
|
||||
oldest_sequence=oldest_sequence,
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
@@ -380,8 +501,11 @@ async def update_tool_message_content(
|
||||
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
|
||||
"""Set durationMs on the last assistant message in a session.
|
||||
|
||||
Also invalidates the Redis session cache so the next GET returns
|
||||
the updated duration.
|
||||
Updates the Redis cache in-place instead of invalidating it.
|
||||
Invalidation would delete the key, creating a window where concurrent
|
||||
``get_chat_session`` calls re-populate the cache from DB — potentially
|
||||
with stale data if the DB write from the previous turn hasn't propagated.
|
||||
This race caused duplicate user messages on the next turn.
|
||||
"""
|
||||
last_msg = await PrismaChatMessage.prisma().find_first(
|
||||
where={"sessionId": session_id, "role": "assistant"},
|
||||
@@ -392,5 +516,13 @@ async def set_turn_duration(session_id: str, duration_ms: int) -> None:
|
||||
where={"id": last_msg.id},
|
||||
data={"durationMs": duration_ms},
|
||||
)
|
||||
# Invalidate cache so the session is re-fetched from DB with durationMs
|
||||
await invalidate_session_cache(session_id)
|
||||
# Update cache in-place rather than invalidating to avoid a
|
||||
# race window where the empty cache gets re-populated with
|
||||
# stale data by a concurrent get_chat_session call.
|
||||
session = await get_chat_session_cached(session_id)
|
||||
if session and session.messages:
|
||||
for msg in reversed(session.messages):
|
||||
if msg.role == "assistant":
|
||||
msg.duration_ms = duration_ms
|
||||
break
|
||||
await cache_chat_session(session)
|
||||
|
||||
388
autogpt_platform/backend/backend/copilot/db_test.py
Normal file
388
autogpt_platform/backend/backend/copilot/db_test.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""Unit tests for copilot.db — paginated message queries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
|
||||
from backend.copilot.db import (
|
||||
PaginatedMessages,
|
||||
get_chat_messages_paginated,
|
||||
set_turn_duration,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage as CopilotChatMessage
|
||||
from backend.copilot.model import ChatSession, get_chat_session, upsert_chat_session
|
||||
|
||||
|
||||
def _make_msg(
|
||||
sequence: int,
|
||||
role: str = "assistant",
|
||||
content: str | None = "hello",
|
||||
tool_calls: Any = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Build a minimal PrismaChatMessage for testing."""
|
||||
return PrismaChatMessage(
|
||||
id=f"msg-{sequence}",
|
||||
createdAt=datetime.now(UTC),
|
||||
sessionId="sess-1",
|
||||
role=role,
|
||||
content=content,
|
||||
sequence=sequence,
|
||||
toolCalls=tool_calls,
|
||||
name=None,
|
||||
toolCallId=None,
|
||||
refusal=None,
|
||||
functionCall=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_session(
|
||||
session_id: str = "sess-1",
|
||||
user_id: str = "user-1",
|
||||
messages: list[PrismaChatMessage] | None = None,
|
||||
) -> PrismaChatSession:
|
||||
"""Build a minimal PrismaChatSession for testing."""
|
||||
now = datetime.now(UTC)
|
||||
session = PrismaChatSession.model_construct(
|
||||
id=session_id,
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
userId=user_id,
|
||||
credentials={},
|
||||
successfulAgentRuns={},
|
||||
successfulAgentSchedules={},
|
||||
totalPromptTokens=0,
|
||||
totalCompletionTokens=0,
|
||||
title=None,
|
||||
metadata={},
|
||||
Messages=messages or [],
|
||||
)
|
||||
return session
|
||||
|
||||
|
||||
SESSION_ID = "sess-1"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_db():
|
||||
"""Patch ChatSession.prisma().find_first and ChatMessage.prisma().find_many.
|
||||
|
||||
find_first is used for the main query (session + included messages).
|
||||
find_many is used only for boundary expansion queries.
|
||||
"""
|
||||
with (
|
||||
patch.object(PrismaChatSession, "prisma") as mock_session_prisma,
|
||||
patch.object(PrismaChatMessage, "prisma") as mock_msg_prisma,
|
||||
):
|
||||
find_first = AsyncMock()
|
||||
mock_session_prisma.return_value.find_first = find_first
|
||||
|
||||
find_many = AsyncMock(return_value=[])
|
||||
mock_msg_prisma.return_value.find_many = find_many
|
||||
|
||||
yield find_first, find_many
|
||||
|
||||
|
||||
# ---------- Basic pagination ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_page_returns_messages_ascending(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Messages are returned in ascending sequence order."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(3), _make_msg(2), _make_msg(1)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
|
||||
|
||||
assert isinstance(page, PaginatedMessages)
|
||||
assert [m.sequence for m in page.messages] == [1, 2, 3]
|
||||
assert page.has_more is False
|
||||
assert page.oldest_sequence == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_more_when_results_exceed_limit(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""has_more is True when DB returns more than limit items."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(3), _make_msg(2), _make_msg(1)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
|
||||
|
||||
assert page is not None
|
||||
assert page.has_more is True
|
||||
assert len(page.messages) == 2
|
||||
assert [m.sequence for m in page.messages] == [2, 3]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_session_returns_no_messages(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(messages=[])
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
|
||||
|
||||
assert page is not None
|
||||
assert page.messages == []
|
||||
assert page.has_more is False
|
||||
assert page.oldest_sequence is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_sequence_filters_correctly(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""before_sequence is passed as a where filter inside the Messages include."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(2), _make_msg(1)],
|
||||
)
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50, before_sequence=5)
|
||||
|
||||
call_kwargs = find_first.call_args
|
||||
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
|
||||
assert include["Messages"]["where"] == {"sequence": {"lt": 5}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_where_on_messages_without_before_sequence(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Without before_sequence, the Messages include has no where clause."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(messages=[_make_msg(1)])
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50)
|
||||
|
||||
call_kwargs = find_first.call_args
|
||||
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
|
||||
assert "where" not in include["Messages"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_id_filter_applied_to_session_where(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""user_id adds a userId filter to the session-level where clause."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(messages=[_make_msg(1)])
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50, user_id="user-abc")
|
||||
|
||||
call_kwargs = find_first.call_args
|
||||
where = call_kwargs.kwargs.get("where") or call_kwargs[1].get("where")
|
||||
assert where["userId"] == "user-abc"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_not_found_returns_none(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Returns None when session doesn't exist or user doesn't own it."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = None
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
|
||||
|
||||
assert page is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_info_included_in_result(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""PaginatedMessages includes session metadata."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(messages=[_make_msg(1)])
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
|
||||
|
||||
assert page is not None
|
||||
assert page.session.session_id == SESSION_ID
|
||||
|
||||
|
||||
# ---------- Backward boundary expansion ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_boundary_expansion_includes_assistant(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""When page starts with a tool message, expand backward to include
|
||||
the owning assistant message."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(5, role="tool"), _make_msg(4, role="tool")],
|
||||
)
|
||||
find_many.return_value = [_make_msg(3, role="assistant")]
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
|
||||
|
||||
assert page is not None
|
||||
assert [m.sequence for m in page.messages] == [3, 4, 5]
|
||||
assert page.messages[0].role == "assistant"
|
||||
assert page.oldest_sequence == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_boundary_expansion_includes_multiple_tool_msgs(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Boundary expansion scans past consecutive tool messages to find
|
||||
the owning assistant."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(7, role="tool")],
|
||||
)
|
||||
find_many.return_value = [
|
||||
_make_msg(6, role="tool"),
|
||||
_make_msg(5, role="tool"),
|
||||
_make_msg(4, role="assistant"),
|
||||
]
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
|
||||
|
||||
assert page is not None
|
||||
assert [m.sequence for m in page.messages] == [4, 5, 6, 7]
|
||||
assert page.messages[0].role == "assistant"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_boundary_expansion_sets_has_more_when_not_at_start(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""After boundary expansion, has_more=True if expanded msgs aren't at seq 0."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(3, role="tool")],
|
||||
)
|
||||
find_many.return_value = [_make_msg(2, role="assistant")]
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
|
||||
|
||||
assert page is not None
|
||||
assert page.has_more is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_boundary_expansion_no_has_more_at_conversation_start(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""has_more stays False when boundary expansion reaches seq 0."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(1, role="tool")],
|
||||
)
|
||||
find_many.return_value = [_make_msg(0, role="assistant")]
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
|
||||
|
||||
assert page is not None
|
||||
assert page.has_more is False
|
||||
assert page.oldest_sequence == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_boundary_expansion_when_first_msg_not_tool(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""No boundary expansion when the first message is not a tool message."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(3, role="user"), _make_msg(2, role="assistant")],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
|
||||
|
||||
assert page is not None
|
||||
assert find_many.call_count == 0
|
||||
assert [m.sequence for m in page.messages] == [2, 3]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_boundary_expansion_warns_when_no_owner_found(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""When boundary scan doesn't find a non-tool message, a warning is logged
|
||||
and the boundary messages are still included."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(10, role="tool")],
|
||||
)
|
||||
find_many.return_value = [_make_msg(i, role="tool") for i in range(9, -1, -1)]
|
||||
|
||||
with patch("backend.copilot.db.logger") as mock_logger:
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
assert page is not None
|
||||
assert page.messages[0].role == "tool"
|
||||
assert len(page.messages) > 1
|
||||
|
||||
|
||||
# ---------- Turn duration (integration tests) ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_set_turn_duration_updates_cache_in_place(setup_test_user, test_user_id):
|
||||
"""set_turn_duration patches the cached session without invalidation.
|
||||
|
||||
Verifies that after calling set_turn_duration the Redis-cached session
|
||||
reflects the updated durationMs on the last assistant message, without
|
||||
the cache having been deleted and re-populated (which could race with
|
||||
concurrent get_chat_session calls).
|
||||
"""
|
||||
session = ChatSession.new(user_id=test_user_id, dry_run=False)
|
||||
session.messages = [
|
||||
CopilotChatMessage(role="user", content="hello"),
|
||||
CopilotChatMessage(role="assistant", content="hi there"),
|
||||
]
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Ensure the session is in cache
|
||||
cached = await get_chat_session(session.session_id, test_user_id)
|
||||
assert cached is not None
|
||||
assert cached.messages[-1].duration_ms is None
|
||||
|
||||
# Update turn duration — should patch cache in-place
|
||||
await set_turn_duration(session.session_id, 1234)
|
||||
|
||||
# Read from cache (not DB) — the cache should already have the update
|
||||
updated = await get_chat_session(session.session_id, test_user_id)
|
||||
assert updated is not None
|
||||
assistant_msgs = [m for m in updated.messages if m.role == "assistant"]
|
||||
assert len(assistant_msgs) == 1
|
||||
assert assistant_msgs[0].duration_ms == 1234
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_set_turn_duration_no_assistant_message(setup_test_user, test_user_id):
|
||||
"""set_turn_duration is a no-op when there are no assistant messages."""
|
||||
session = ChatSession.new(user_id=test_user_id, dry_run=False)
|
||||
session.messages = [
|
||||
CopilotChatMessage(role="user", content="hello"),
|
||||
]
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Should not raise
|
||||
await set_turn_duration(session.session_id, 5678)
|
||||
|
||||
cached = await get_chat_session(session.session_id, test_user_id)
|
||||
assert cached is not None
|
||||
# User message should not have durationMs
|
||||
assert cached.messages[0].duration_ms is None
|
||||
@@ -13,7 +13,7 @@ import time
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.config import ChatConfig, CopilotMode
|
||||
from backend.copilot.response_model import StreamError
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
@@ -30,6 +30,57 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||
|
||||
|
||||
# ============ Mode Routing ============ #
|
||||
|
||||
|
||||
async def resolve_effective_mode(
|
||||
mode: CopilotMode | None,
|
||||
user_id: str | None,
|
||||
) -> CopilotMode | None:
|
||||
"""Strip ``mode`` when the user is not entitled to the toggle.
|
||||
|
||||
The UI gates the mode toggle behind ``CHAT_MODE_OPTION``; the
|
||||
processor enforces the same gate server-side so an authenticated
|
||||
user cannot bypass the flag by crafting a request directly.
|
||||
"""
|
||||
if mode is None:
|
||||
return None
|
||||
allowed = await is_feature_enabled(
|
||||
Flag.CHAT_MODE_OPTION,
|
||||
user_id or "anonymous",
|
||||
default=False,
|
||||
)
|
||||
if not allowed:
|
||||
logger.info(f"Ignoring mode={mode} — CHAT_MODE_OPTION is disabled for user")
|
||||
return None
|
||||
return mode
|
||||
|
||||
|
||||
async def resolve_use_sdk_for_mode(
|
||||
mode: CopilotMode | None,
|
||||
user_id: str | None,
|
||||
*,
|
||||
use_claude_code_subscription: bool,
|
||||
config_default: bool,
|
||||
) -> bool:
|
||||
"""Pick the SDK vs baseline path for a single turn.
|
||||
|
||||
Per-request ``mode`` wins whenever it is set (after the
|
||||
``CHAT_MODE_OPTION`` gate has been applied upstream). Otherwise
|
||||
falls back to the Claude Code subscription override, then the
|
||||
``COPILOT_SDK`` LaunchDarkly flag, then the config default.
|
||||
"""
|
||||
if mode == "fast":
|
||||
return False
|
||||
if mode == "extended_thinking":
|
||||
return True
|
||||
return use_claude_code_subscription or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
user_id or "anonymous",
|
||||
default=config_default,
|
||||
)
|
||||
|
||||
|
||||
# ============ Module Entry Points ============ #
|
||||
|
||||
# Thread-local storage for processor instances
|
||||
@@ -100,8 +151,8 @@ class CoPilotProcessor:
|
||||
This method is called once per worker thread to set up the async event
|
||||
loop and initialize any required resources.
|
||||
|
||||
Database is accessed only through DatabaseManager, so we don't need to connect
|
||||
to Prisma directly.
|
||||
DB operations route through DatabaseManagerAsyncClient (RPC) via the
|
||||
db_accessors pattern — no direct Prisma connection is needed here.
|
||||
"""
|
||||
configure_logging()
|
||||
set_service_name("CoPilotExecutor")
|
||||
@@ -250,21 +301,26 @@ class CoPilotProcessor:
|
||||
if config.test_mode:
|
||||
stream_fn = stream_chat_completion_dummy
|
||||
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
|
||||
effective_mode = None
|
||||
else:
|
||||
use_sdk = (
|
||||
config.use_claude_code_subscription
|
||||
or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
# Enforce server-side feature-flag gate so unauthorised
|
||||
# users cannot force a mode by crafting the request.
|
||||
effective_mode = await resolve_effective_mode(entry.mode, entry.user_id)
|
||||
use_sdk = await resolve_use_sdk_for_mode(
|
||||
effective_mode,
|
||||
entry.user_id,
|
||||
use_claude_code_subscription=config.use_claude_code_subscription,
|
||||
config_default=config.use_claude_agent_sdk,
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
log.info(
|
||||
f"Using {'SDK' if use_sdk else 'baseline'} service "
|
||||
f"(mode={effective_mode or 'default'})"
|
||||
)
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
# stream_and_publish wraps the raw stream with registry
|
||||
@@ -276,6 +332,7 @@ class CoPilotProcessor:
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
mode=effective_mode,
|
||||
)
|
||||
async for chunk in stream_registry.stream_and_publish(
|
||||
session_id=entry.session_id,
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
"""Unit tests for CoPilot mode routing logic in the processor.
|
||||
|
||||
Tests cover the mode→service mapping:
|
||||
- 'fast' → baseline service
|
||||
- 'extended_thinking' → SDK service
|
||||
- None → feature flag / config fallback
|
||||
|
||||
as well as the ``CHAT_MODE_OPTION`` server-side gate. The tests import
|
||||
the real production helpers from ``processor.py`` so the routing logic
|
||||
has meaningful coverage.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.executor.processor import (
|
||||
resolve_effective_mode,
|
||||
resolve_use_sdk_for_mode,
|
||||
)
|
||||
|
||||
|
||||
class TestResolveUseSdkForMode:
|
||||
"""Tests for the per-request mode routing logic."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_mode_uses_baseline(self):
|
||||
"""mode='fast' always routes to baseline, regardless of flags."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=True),
|
||||
):
|
||||
assert (
|
||||
await resolve_use_sdk_for_mode(
|
||||
"fast",
|
||||
"user-1",
|
||||
use_claude_code_subscription=True,
|
||||
config_default=True,
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extended_thinking_uses_sdk(self):
|
||||
"""mode='extended_thinking' always routes to SDK, regardless of flags."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert (
|
||||
await resolve_use_sdk_for_mode(
|
||||
"extended_thinking",
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
config_default=False,
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_mode_uses_subscription_override(self):
|
||||
"""mode=None with claude_code_subscription=True routes to SDK."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert (
|
||||
await resolve_use_sdk_for_mode(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=True,
|
||||
config_default=False,
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_mode_uses_feature_flag(self):
|
||||
"""mode=None with feature flag enabled routes to SDK."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=True),
|
||||
) as flag_mock:
|
||||
assert (
|
||||
await resolve_use_sdk_for_mode(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
config_default=False,
|
||||
)
|
||||
is True
|
||||
)
|
||||
flag_mock.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_mode_uses_config_default(self):
|
||||
"""mode=None falls back to config.use_claude_agent_sdk."""
|
||||
# When LaunchDarkly returns the default (True), we expect SDK routing.
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=True),
|
||||
):
|
||||
assert (
|
||||
await resolve_use_sdk_for_mode(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
config_default=True,
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_mode_all_disabled(self):
|
||||
"""mode=None with all flags off routes to baseline."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert (
|
||||
await resolve_use_sdk_for_mode(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
config_default=False,
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
class TestResolveEffectiveMode:
|
||||
"""Tests for the CHAT_MODE_OPTION server-side gate."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_mode_passes_through(self):
|
||||
"""mode=None is returned as-is without a flag check."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
) as flag_mock:
|
||||
assert await resolve_effective_mode(None, "user-1") is None
|
||||
flag_mock.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mode_stripped_when_flag_disabled(self):
|
||||
"""When CHAT_MODE_OPTION is off, mode is dropped to None."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert await resolve_effective_mode("fast", "user-1") is None
|
||||
assert await resolve_effective_mode("extended_thinking", "user-1") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mode_preserved_when_flag_enabled(self):
|
||||
"""When CHAT_MODE_OPTION is on, the user-selected mode is preserved."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=True),
|
||||
):
|
||||
assert await resolve_effective_mode("fast", "user-1") == "fast"
|
||||
assert (
|
||||
await resolve_effective_mode("extended_thinking", "user-1")
|
||||
== "extended_thinking"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anonymous_user_with_mode(self):
|
||||
"""Anonymous users (user_id=None) still pass through the gate."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
) as flag_mock:
|
||||
assert await resolve_effective_mode("fast", None) is None
|
||||
flag_mock.assert_awaited_once()
|
||||
@@ -9,6 +9,7 @@ import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import CopilotMode
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
|
||||
@@ -156,6 +157,9 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
file_ids: list[str] | None = None
|
||||
"""Workspace file IDs attached to the user's message"""
|
||||
|
||||
mode: CopilotMode | None = None
|
||||
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
@@ -175,6 +179,7 @@ async def enqueue_copilot_turn(
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
mode: CopilotMode | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
@@ -186,6 +191,7 @@ async def enqueue_copilot_turn(
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
file_ids: Optional workspace file IDs attached to the user's message
|
||||
mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
@@ -197,6 +203,7 @@ async def enqueue_copilot_turn(
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
file_ids=file_ids,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
|
||||
123
autogpt_platform/backend/backend/copilot/executor/utils_test.py
Normal file
123
autogpt_platform/backend/backend/copilot/executor/utils_test.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Tests for CoPilot executor utils (queue config, message models, logging)."""
|
||||
|
||||
from backend.copilot.executor.utils import (
|
||||
COPILOT_EXECUTION_EXCHANGE,
|
||||
COPILOT_EXECUTION_QUEUE_NAME,
|
||||
COPILOT_EXECUTION_ROUTING_KEY,
|
||||
CancelCoPilotEvent,
|
||||
CoPilotExecutionEntry,
|
||||
CoPilotLogMetadata,
|
||||
create_copilot_queue_config,
|
||||
)
|
||||
|
||||
|
||||
class TestCoPilotExecutionEntry:
|
||||
def test_basic_fields(self):
|
||||
entry = CoPilotExecutionEntry(
|
||||
session_id="s1",
|
||||
user_id="u1",
|
||||
message="hello",
|
||||
)
|
||||
assert entry.session_id == "s1"
|
||||
assert entry.user_id == "u1"
|
||||
assert entry.message == "hello"
|
||||
assert entry.is_user_message is True
|
||||
assert entry.mode is None
|
||||
assert entry.context is None
|
||||
assert entry.file_ids is None
|
||||
|
||||
def test_mode_field(self):
|
||||
entry = CoPilotExecutionEntry(
|
||||
session_id="s1",
|
||||
user_id="u1",
|
||||
message="test",
|
||||
mode="fast",
|
||||
)
|
||||
assert entry.mode == "fast"
|
||||
|
||||
entry2 = CoPilotExecutionEntry(
|
||||
session_id="s1",
|
||||
user_id="u1",
|
||||
message="test",
|
||||
mode="extended_thinking",
|
||||
)
|
||||
assert entry2.mode == "extended_thinking"
|
||||
|
||||
def test_optional_fields(self):
|
||||
entry = CoPilotExecutionEntry(
|
||||
session_id="s1",
|
||||
user_id="u1",
|
||||
message="test",
|
||||
turn_id="t1",
|
||||
context={"url": "https://example.com"},
|
||||
file_ids=["f1", "f2"],
|
||||
is_user_message=False,
|
||||
)
|
||||
assert entry.turn_id == "t1"
|
||||
assert entry.context == {"url": "https://example.com"}
|
||||
assert entry.file_ids == ["f1", "f2"]
|
||||
assert entry.is_user_message is False
|
||||
|
||||
def test_serialization_roundtrip(self):
|
||||
entry = CoPilotExecutionEntry(
|
||||
session_id="s1",
|
||||
user_id="u1",
|
||||
message="hello",
|
||||
mode="fast",
|
||||
)
|
||||
json_str = entry.model_dump_json()
|
||||
restored = CoPilotExecutionEntry.model_validate_json(json_str)
|
||||
assert restored == entry
|
||||
|
||||
|
||||
class TestCancelCoPilotEvent:
|
||||
def test_basic(self):
|
||||
event = CancelCoPilotEvent(session_id="s1")
|
||||
assert event.session_id == "s1"
|
||||
|
||||
def test_serialization(self):
|
||||
event = CancelCoPilotEvent(session_id="s1")
|
||||
restored = CancelCoPilotEvent.model_validate_json(event.model_dump_json())
|
||||
assert restored.session_id == "s1"
|
||||
|
||||
|
||||
class TestCreateCopilotQueueConfig:
|
||||
def test_returns_valid_config(self):
|
||||
config = create_copilot_queue_config()
|
||||
assert len(config.exchanges) == 2
|
||||
assert len(config.queues) == 2
|
||||
|
||||
def test_execution_queue_properties(self):
|
||||
config = create_copilot_queue_config()
|
||||
exec_queue = next(
|
||||
q for q in config.queues if q.name == COPILOT_EXECUTION_QUEUE_NAME
|
||||
)
|
||||
assert exec_queue.durable is True
|
||||
assert exec_queue.exchange == COPILOT_EXECUTION_EXCHANGE
|
||||
assert exec_queue.routing_key == COPILOT_EXECUTION_ROUTING_KEY
|
||||
|
||||
def test_cancel_queue_uses_fanout(self):
|
||||
config = create_copilot_queue_config()
|
||||
cancel_queue = next(
|
||||
q for q in config.queues if q.name != COPILOT_EXECUTION_QUEUE_NAME
|
||||
)
|
||||
assert cancel_queue.exchange is not None
|
||||
assert cancel_queue.exchange.type.value == "fanout"
|
||||
|
||||
|
||||
class TestCoPilotLogMetadata:
|
||||
def test_creates_logger_with_metadata(self):
|
||||
import logging
|
||||
|
||||
base_logger = logging.getLogger("test")
|
||||
log = CoPilotLogMetadata(base_logger, session_id="s1", user_id="u1")
|
||||
assert log is not None
|
||||
|
||||
def test_filters_none_values(self):
|
||||
import logging
|
||||
|
||||
base_logger = logging.getLogger("test")
|
||||
log = CoPilotLogMetadata(
|
||||
base_logger, session_id="s1", user_id=None, turn_id="t1"
|
||||
)
|
||||
assert log is not None
|
||||
@@ -59,6 +59,16 @@ _null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
# GitHub user identity caches (keyed by user_id only, not provider tuple).
|
||||
# Declared here so invalidate_user_provider_cache() can reference them.
|
||||
_GH_IDENTITY_CACHE_TTL = 600.0 # 10 min — profile data rarely changes
|
||||
_gh_identity_cache: TTLCache[str, dict[str, str]] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_GH_IDENTITY_CACHE_TTL
|
||||
)
|
||||
_gh_identity_null_cache: TTLCache[str, bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
@@ -66,11 +76,19 @@ def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
|
||||
For GitHub specifically, also clears the git-identity caches so that
|
||||
``get_github_user_git_identity()`` re-fetches the user's profile on
|
||||
the next call instead of serving stale identity data.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
if provider == "github":
|
||||
_gh_identity_cache.pop(user_id, None)
|
||||
_gh_identity_null_cache.pop(user_id, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
@@ -123,6 +141,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
refresh_failed = False
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
@@ -141,6 +160,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
# Do NOT fall back to the stale token — it is likely expired
|
||||
# or revoked. Returning None forces the caller to re-auth,
|
||||
# preventing the LLM from receiving a non-functional token.
|
||||
refresh_failed = True
|
||||
continue
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
@@ -152,8 +172,12 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
# Only cache "not connected" when the user truly has no credentials for this
|
||||
# provider. If we had OAuth credentials but refresh failed (e.g. transient
|
||||
# network error, event-loop mismatch), do NOT cache the negative result —
|
||||
# the next call should retry the refresh instead of being blocked for 60 s.
|
||||
if not refresh_failed:
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
@@ -171,3 +195,76 @@ async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GitHub user identity (for git committer env vars)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def get_github_user_git_identity(user_id: str) -> dict[str, str] | None:
|
||||
"""Fetch the GitHub user's name and email for git committer env vars.
|
||||
|
||||
Uses the ``/user`` GitHub API endpoint with the user's stored token.
|
||||
Returns a dict with ``GIT_AUTHOR_NAME``, ``GIT_AUTHOR_EMAIL``,
|
||||
``GIT_COMMITTER_NAME``, and ``GIT_COMMITTER_EMAIL`` if the user has a
|
||||
connected GitHub account. Returns ``None`` otherwise.
|
||||
|
||||
Results are cached for 10 minutes; "not connected" results are cached for
|
||||
60 s (same as null-token cache).
|
||||
"""
|
||||
if user_id in _gh_identity_null_cache:
|
||||
return None
|
||||
if cached := _gh_identity_cache.get(user_id):
|
||||
return cached
|
||||
|
||||
token = await get_provider_token(user_id, "github")
|
||||
if not token:
|
||||
_gh_identity_null_cache[user_id] = True
|
||||
return None
|
||||
|
||||
import aiohttp
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
"https://api.github.com/user",
|
||||
headers={
|
||||
"Authorization": f"token {token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
logger.warning(
|
||||
"[git-identity] GitHub /user returned %s for user %s",
|
||||
resp.status,
|
||||
user_id,
|
||||
)
|
||||
return None
|
||||
data = await resp.json()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[git-identity] Failed to fetch GitHub profile for user %s: %s",
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
name = data.get("name") or data.get("login") or "AutoGPT User"
|
||||
# GitHub may return email=null if the user has set their email to private.
|
||||
# Fall back to the noreply address GitHub generates for every account.
|
||||
email = data.get("email")
|
||||
if not email:
|
||||
gh_id = data.get("id", "")
|
||||
login = data.get("login", "user")
|
||||
email = f"{gh_id}+{login}@users.noreply.github.com"
|
||||
|
||||
identity = {
|
||||
"GIT_AUTHOR_NAME": name,
|
||||
"GIT_AUTHOR_EMAIL": email,
|
||||
"GIT_COMMITTER_NAME": name,
|
||||
"GIT_COMMITTER_EMAIL": email,
|
||||
}
|
||||
_gh_identity_cache[user_id] = identity
|
||||
return identity
|
||||
|
||||
@@ -9,6 +9,8 @@ from backend.copilot.integration_creds import (
|
||||
_NULL_CACHE_TTL,
|
||||
_TOKEN_CACHE_TTL,
|
||||
PROVIDER_ENV_VARS,
|
||||
_gh_identity_cache,
|
||||
_gh_identity_null_cache,
|
||||
_null_cache,
|
||||
_token_cache,
|
||||
get_integration_env_vars,
|
||||
@@ -49,9 +51,13 @@ def clear_caches():
|
||||
"""Ensure clean caches before and after every test."""
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
_gh_identity_cache.clear()
|
||||
_gh_identity_null_cache.clear()
|
||||
yield
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
_gh_identity_cache.clear()
|
||||
_gh_identity_null_cache.clear()
|
||||
|
||||
|
||||
class TestInvalidateUserProviderCache:
|
||||
@@ -77,6 +83,34 @@ class TestInvalidateUserProviderCache:
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert other_key in _token_cache
|
||||
|
||||
def test_clears_gh_identity_cache_for_github_provider(self):
|
||||
"""When provider is 'github', identity caches must also be cleared."""
|
||||
_gh_identity_cache[_USER] = {
|
||||
"GIT_AUTHOR_NAME": "Old Name",
|
||||
"GIT_AUTHOR_EMAIL": "old@example.com",
|
||||
"GIT_COMMITTER_NAME": "Old Name",
|
||||
"GIT_COMMITTER_EMAIL": "old@example.com",
|
||||
}
|
||||
invalidate_user_provider_cache(_USER, "github")
|
||||
assert _USER not in _gh_identity_cache
|
||||
|
||||
def test_clears_gh_identity_null_cache_for_github_provider(self):
|
||||
"""When provider is 'github', the identity null-cache must also be cleared."""
|
||||
_gh_identity_null_cache[_USER] = True
|
||||
invalidate_user_provider_cache(_USER, "github")
|
||||
assert _USER not in _gh_identity_null_cache
|
||||
|
||||
def test_does_not_clear_gh_identity_cache_for_other_providers(self):
|
||||
"""When provider is NOT 'github', identity caches must be left alone."""
|
||||
_gh_identity_cache[_USER] = {
|
||||
"GIT_AUTHOR_NAME": "Some Name",
|
||||
"GIT_AUTHOR_EMAIL": "some@example.com",
|
||||
"GIT_COMMITTER_NAME": "Some Name",
|
||||
"GIT_COMMITTER_EMAIL": "some@example.com",
|
||||
}
|
||||
invalidate_user_provider_cache(_USER, "some-other-provider")
|
||||
assert _USER in _gh_identity_cache
|
||||
|
||||
|
||||
class TestGetProviderToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -129,8 +163,15 @@ class TestGetProviderToken:
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_returns_none(self):
|
||||
"""On refresh failure, return None instead of caching a stale token."""
|
||||
async def test_oauth2_refresh_failure_returns_none_without_null_cache(self):
|
||||
"""On refresh failure, return None but do NOT cache in null_cache.
|
||||
|
||||
The user has credentials — they just couldn't be refreshed right now
|
||||
(e.g. transient network error or event-loop mismatch in the copilot
|
||||
executor). Caching a negative result would block all credential
|
||||
lookups for 60 s even though the creds exist and may refresh fine
|
||||
on the next attempt.
|
||||
"""
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
@@ -141,6 +182,8 @@ class TestGetProviderToken:
|
||||
|
||||
# Stale tokens must NOT be returned — forces re-auth.
|
||||
assert result is None
|
||||
# Must NOT cache negative result when refresh failed — next call retries.
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
@@ -176,6 +219,96 @@ class TestGetProviderToken:
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestThreadSafetyLocks:
|
||||
"""Bug reproduction: shared AsyncRedisKeyedMutex across threads caused
|
||||
'Future attached to a different loop' when copilot workers accessed
|
||||
credentials from different event loops."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_locks_returns_per_thread_instance(self):
|
||||
"""IntegrationCredentialsStore.locks() must return different instances
|
||||
for different threads (via @thread_cached)."""
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
store = IntegrationCredentialsStore()
|
||||
|
||||
async def get_locks_id():
|
||||
mock_redis = AsyncMock()
|
||||
with patch(
|
||||
"backend.integrations.credentials_store.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
locks = await store.locks()
|
||||
return id(locks)
|
||||
|
||||
# Get locks from main thread
|
||||
main_id = await get_locks_id()
|
||||
|
||||
# Get locks from a worker thread
|
||||
def run_in_thread():
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(get_locks_id())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
worker_id = await asyncio.get_event_loop().run_in_executor(
|
||||
pool, run_in_thread
|
||||
)
|
||||
|
||||
assert main_id != worker_id, (
|
||||
"Store.locks() returned the same instance across threads. "
|
||||
"This would cause 'Future attached to a different loop' errors."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_manager_delegates_to_store_locks(self):
|
||||
"""IntegrationCredentialsManager.locks() should delegate to store."""
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
manager = IntegrationCredentialsManager()
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.integrations.credentials_store.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
locks = await manager.locks()
|
||||
|
||||
# Should have gotten it from the store
|
||||
assert locks is not None
|
||||
|
||||
|
||||
class TestRefreshUnlockedPath:
|
||||
"""Bug reproduction: copilot worker threads need lock-free refresh because
|
||||
Redis-backed asyncio.Lock created on one event loop can't be used on another."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_refresh_if_needed_lock_false_skips_redis(self):
|
||||
"""refresh_if_needed(lock=False) must not touch Redis locks at all."""
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
manager = IntegrationCredentialsManager()
|
||||
creds = _make_oauth2_creds()
|
||||
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.needs_refresh = MagicMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"backend.integrations.creds_manager._get_provider_oauth_handler",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_handler,
|
||||
):
|
||||
result = await manager.refresh_if_needed(_USER, creds, lock=False)
|
||||
|
||||
# Should return credentials without touching locks
|
||||
assert result.id == creds.id
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
|
||||
@@ -64,6 +64,7 @@ class ChatMessage(BaseModel):
|
||||
refusal: str | None = None
|
||||
tool_calls: list[dict] | None = None
|
||||
function_call: dict | None = None
|
||||
sequence: int | None = None
|
||||
duration_ms: int | None = None
|
||||
|
||||
@staticmethod
|
||||
@@ -77,10 +78,54 @@ class ChatMessage(BaseModel):
|
||||
refusal=prisma_message.refusal,
|
||||
tool_calls=_parse_json_field(prisma_message.toolCalls),
|
||||
function_call=_parse_json_field(prisma_message.functionCall),
|
||||
sequence=prisma_message.sequence,
|
||||
duration_ms=prisma_message.durationMs,
|
||||
)
|
||||
|
||||
|
||||
def is_message_duplicate(
|
||||
messages: list[ChatMessage],
|
||||
role: str,
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""Check whether *content* is already present in the current pending turn.
|
||||
|
||||
Only inspects trailing messages that share the given *role* (i.e. the
|
||||
current turn). This ensures legitimately repeated messages across different
|
||||
turns are not suppressed, while same-turn duplicates from stale cache are
|
||||
still caught.
|
||||
"""
|
||||
for m in reversed(messages):
|
||||
if m.role == role:
|
||||
if m.content == content:
|
||||
return True
|
||||
else:
|
||||
break
|
||||
return False
|
||||
|
||||
|
||||
def maybe_append_user_message(
|
||||
session: "ChatSession",
|
||||
message: str | None,
|
||||
is_user_message: bool,
|
||||
) -> bool:
|
||||
"""Append a user/assistant message to the session if not already present.
|
||||
|
||||
The route handler already persists the user message before enqueueing,
|
||||
so we check trailing same-role messages to avoid re-appending when the
|
||||
session cache is slightly stale.
|
||||
|
||||
Returns True if the message was appended, False if skipped.
|
||||
"""
|
||||
if not message:
|
||||
return False
|
||||
role = "user" if is_user_message else "assistant"
|
||||
if is_message_duplicate(session.messages, role, message):
|
||||
return False
|
||||
session.messages.append(ChatMessage(role=role, content=message))
|
||||
return True
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
|
||||
@@ -17,6 +17,8 @@ from .model import (
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
is_message_duplicate,
|
||||
maybe_append_user_message,
|
||||
upsert_chat_session,
|
||||
)
|
||||
|
||||
@@ -424,3 +426,151 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
|
||||
assert "Streaming message 1" in contents
|
||||
assert "Streaming message 2" in contents
|
||||
assert "Callback result" in contents
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# is_message_duplicate #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def test_duplicate_detected_in_trailing_same_role():
|
||||
"""Duplicate user message at the tail is detected."""
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="hello"),
|
||||
ChatMessage(role="assistant", content="hi there"),
|
||||
ChatMessage(role="user", content="yes"),
|
||||
]
|
||||
assert is_message_duplicate(msgs, "user", "yes") is True
|
||||
|
||||
|
||||
def test_duplicate_not_detected_across_turns():
|
||||
"""Same text in a previous turn (separated by assistant) is NOT a duplicate."""
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="yes"),
|
||||
ChatMessage(role="assistant", content="ok"),
|
||||
]
|
||||
assert is_message_duplicate(msgs, "user", "yes") is False
|
||||
|
||||
|
||||
def test_no_duplicate_on_empty_messages():
|
||||
"""Empty message list never reports a duplicate."""
|
||||
assert is_message_duplicate([], "user", "hello") is False
|
||||
|
||||
|
||||
def test_no_duplicate_when_content_differs():
|
||||
"""Different content in the trailing same-role block is not a duplicate."""
|
||||
msgs = [
|
||||
ChatMessage(role="assistant", content="response"),
|
||||
ChatMessage(role="user", content="first message"),
|
||||
]
|
||||
assert is_message_duplicate(msgs, "user", "second message") is False
|
||||
|
||||
|
||||
def test_duplicate_with_multiple_trailing_same_role():
|
||||
"""Detects duplicate among multiple consecutive same-role messages."""
|
||||
msgs = [
|
||||
ChatMessage(role="assistant", content="response"),
|
||||
ChatMessage(role="user", content="msg1"),
|
||||
ChatMessage(role="user", content="msg2"),
|
||||
]
|
||||
assert is_message_duplicate(msgs, "user", "msg1") is True
|
||||
assert is_message_duplicate(msgs, "user", "msg2") is True
|
||||
assert is_message_duplicate(msgs, "user", "msg3") is False
|
||||
|
||||
|
||||
def test_duplicate_check_for_assistant_role():
|
||||
"""Works correctly when checking assistant role too."""
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="hello"),
|
||||
ChatMessage(role="assistant", content="how can I help?"),
|
||||
]
|
||||
assert is_message_duplicate(msgs, "assistant", "hello") is True
|
||||
assert is_message_duplicate(msgs, "assistant", "new response") is False
|
||||
|
||||
|
||||
def test_no_false_positive_when_content_is_none():
|
||||
"""Messages with content=None in the trailing block do not match."""
|
||||
msgs = [
|
||||
ChatMessage(role="user", content=None),
|
||||
ChatMessage(role="user", content="hello"),
|
||||
]
|
||||
assert is_message_duplicate(msgs, "user", "hello") is True
|
||||
# None-content message should not match any string
|
||||
msgs2 = [
|
||||
ChatMessage(role="user", content=None),
|
||||
]
|
||||
assert is_message_duplicate(msgs2, "user", "hello") is False
|
||||
|
||||
|
||||
def test_all_same_role_messages():
|
||||
"""When all messages share the same role, the entire list is scanned."""
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="first"),
|
||||
ChatMessage(role="user", content="second"),
|
||||
ChatMessage(role="user", content="third"),
|
||||
]
|
||||
assert is_message_duplicate(msgs, "user", "first") is True
|
||||
assert is_message_duplicate(msgs, "user", "new") is False
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# maybe_append_user_message #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def test_maybe_append_user_message_appends_new():
|
||||
"""A new user message is appended and returns True."""
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session.messages = [
|
||||
ChatMessage(role="assistant", content="hello"),
|
||||
]
|
||||
result = maybe_append_user_message(session, "new msg", is_user_message=True)
|
||||
assert result is True
|
||||
assert len(session.messages) == 2
|
||||
assert session.messages[-1].role == "user"
|
||||
assert session.messages[-1].content == "new msg"
|
||||
|
||||
|
||||
def test_maybe_append_user_message_skips_duplicate():
|
||||
"""A duplicate user message is skipped and returns False."""
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session.messages = [
|
||||
ChatMessage(role="assistant", content="hello"),
|
||||
ChatMessage(role="user", content="dup"),
|
||||
]
|
||||
result = maybe_append_user_message(session, "dup", is_user_message=True)
|
||||
assert result is False
|
||||
assert len(session.messages) == 2
|
||||
|
||||
|
||||
def test_maybe_append_user_message_none_message():
|
||||
"""None/empty message returns False without appending."""
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
assert maybe_append_user_message(session, None, is_user_message=True) is False
|
||||
assert maybe_append_user_message(session, "", is_user_message=True) is False
|
||||
assert len(session.messages) == 0
|
||||
|
||||
|
||||
def test_maybe_append_assistant_message():
|
||||
"""Works for assistant role when is_user_message=False."""
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="hi"),
|
||||
]
|
||||
result = maybe_append_user_message(session, "response", is_user_message=False)
|
||||
assert result is True
|
||||
assert session.messages[-1].role == "assistant"
|
||||
assert session.messages[-1].content == "response"
|
||||
|
||||
|
||||
def test_maybe_append_assistant_skips_duplicate():
|
||||
"""Duplicate assistant message is skipped."""
|
||||
session = ChatSession.new(user_id="u", dry_run=False)
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="dup"),
|
||||
]
|
||||
result = maybe_append_user_message(session, "dup", is_user_message=False)
|
||||
assert result is False
|
||||
assert len(session.messages) == 2
|
||||
|
||||
@@ -66,6 +66,7 @@ from pydantic import BaseModel, PrivateAttr
|
||||
ToolName = Literal[
|
||||
# Platform tools (must match keys in TOOL_REGISTRY)
|
||||
"add_understanding",
|
||||
"ask_question",
|
||||
"bash_exec",
|
||||
"browser_act",
|
||||
"browser_navigate",
|
||||
@@ -102,6 +103,7 @@ ToolName = Literal[
|
||||
"web_fetch",
|
||||
"write_workspace_file",
|
||||
# SDK built-ins
|
||||
"Agent",
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
|
||||
@@ -544,6 +544,7 @@ class TestApplyToolPermissions:
|
||||
class TestSdkBuiltinToolNames:
|
||||
def test_expected_builtins_present(self):
|
||||
expected = {
|
||||
"Agent",
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
|
||||
@@ -18,6 +18,18 @@ After `write_workspace_file`, embed the `download_url` in Markdown:
|
||||
- Image: ``
|
||||
- Video: ``
|
||||
|
||||
### Handling binary/image data in tool outputs — CRITICAL
|
||||
When a tool output contains base64-encoded binary data (images, PDFs, etc.):
|
||||
1. **NEVER** try to inline or render the base64 content in your response.
|
||||
2. **Save** the data to workspace using `write_workspace_file` (pass the base64 data URI as content).
|
||||
3. **Show** the result via the workspace download URL in Markdown: ``.
|
||||
|
||||
### Passing large data between tools — CRITICAL
|
||||
When tool outputs produce large text that you need to feed into another tool:
|
||||
- **NEVER** copy-paste the full text into the next tool call argument.
|
||||
- **Save** the output to a file (workspace or local), then use `@@agptfile:` references.
|
||||
- This avoids token limits and ensures data integrity.
|
||||
|
||||
### File references — @@agptfile:
|
||||
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
|
||||
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
|
||||
@@ -114,6 +126,21 @@ After building the file, reference it with `@@agptfile:` in other tools:
|
||||
- When spawning sub-agents for research, ensure each has a distinct
|
||||
non-overlapping scope to avoid redundant searches.
|
||||
|
||||
|
||||
### Tool Discovery Priority
|
||||
|
||||
When the user asks to interact with a service or API, follow this order:
|
||||
|
||||
1. **find_block first** — Search platform blocks with `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Docs, Calendar, Gmail, Slack, GitHub, etc.) that work without extra setup.
|
||||
|
||||
2. **run_mcp_tool** — If no matching block exists, check if a hosted MCP server is available for the service. Only use known MCP server URLs from the registry.
|
||||
|
||||
3. **SendAuthenticatedWebRequestBlock** — If no block or MCP server exists, use `SendAuthenticatedWebRequestBlock` with existing host-scoped credentials. Check available credentials via `connect_integration`.
|
||||
|
||||
4. **Manual API call** — As a last resort, guide the user to set up credentials and use `SendAuthenticatedWebRequestBlock` with direct API calls.
|
||||
|
||||
**Never skip step 1.** Built-in blocks are more reliable, tested, and user-friendly than MCP or raw API calls.
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
@@ -138,6 +165,11 @@ parent autopilot handles orchestration.
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
|
||||
_E2B_TOOL_NOTES = """
|
||||
### SDK tool-result files in E2B
|
||||
When you `Read` an SDK tool-result file, it is automatically copied into the
|
||||
sandbox so `bash_exec` can access it for further processing.
|
||||
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
|
||||
|
||||
### GitHub CLI (`gh`) and git
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
@@ -203,19 +235,22 @@ def _build_storage_supplement(
|
||||
- Files here **survive across sessions indefinitely**
|
||||
|
||||
### Moving files between storages
|
||||
- **{file_move_name_1_to_2}**: Copy to persistent workspace
|
||||
- **{file_move_name_2_to_1}**: Download for processing
|
||||
- **{file_move_name_1_to_2}**: `write_workspace_file(filename="output.json", source_path="/path/to/local/file")`
|
||||
- **{file_move_name_2_to_1}**: `read_workspace_file(path="tool-outputs/data.json", save_to_path="{working_dir}/data.json")`
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
|
||||
### SDK tool-result files
|
||||
When tool outputs are large, the SDK truncates them and saves the full output to
|
||||
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
|
||||
always use `Read` (NOT `bash_exec`, NOT `read_workspace_file`).
|
||||
These files are on the host filesystem — `bash_exec` runs in the sandbox and
|
||||
CANNOT access them. `read_workspace_file` reads from cloud workspace storage,
|
||||
where SDK tool-results are NOT stored.
|
||||
a local file under `~/.claude/projects/.../tool-results/` (or `tool-outputs/`).
|
||||
To read these files, use `Read` — it reads from the host filesystem.
|
||||
|
||||
### Large tool outputs saved to workspace
|
||||
When a tool output contains `<tool-output-truncated workspace_path="...">`, the
|
||||
full output is in workspace storage (NOT on the local filesystem). To access it:
|
||||
- Use `read_workspace_file(path="...", offset=..., length=50000)` for reading sections.
|
||||
- To process in the sandbox, use `read_workspace_file(path="...", save_to_path="{working_dir}/file.json")` first, then use `bash_exec` on the local copy.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
|
||||
|
||||
|
||||
@@ -6,16 +6,23 @@ from pathlib import Path
|
||||
class TestAgentGenerationGuideContainsClarifySection:
|
||||
"""The agent generation guide must include the clarification section."""
|
||||
|
||||
def test_guide_includes_clarify_before_building(self):
|
||||
def test_guide_includes_clarify_section(self):
|
||||
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
|
||||
content = guide_path.read_text(encoding="utf-8")
|
||||
assert "Clarifying Before Building" in content
|
||||
assert "Before or During Building" in content
|
||||
|
||||
def test_guide_mentions_find_block_for_clarification(self):
|
||||
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
|
||||
content = guide_path.read_text(encoding="utf-8")
|
||||
# find_block must appear in the clarification section (before the workflow)
|
||||
clarify_section = content.split("Clarifying Before Building")[1].split(
|
||||
clarify_section = content.split("Before or During Building")[1].split(
|
||||
"### Workflow"
|
||||
)[0]
|
||||
assert "find_block" in clarify_section
|
||||
|
||||
def test_guide_mentions_ask_question_tool(self):
|
||||
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
|
||||
content = guide_path.read_text(encoding="utf-8")
|
||||
clarify_section = content.split("Before or During Building")[1].split(
|
||||
"### Workflow"
|
||||
)[0]
|
||||
assert "ask_question" in clarify_section
|
||||
|
||||
@@ -9,11 +9,15 @@ UTC). Fails open when Redis is unavailable to avoid blocking users.
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
|
||||
from prisma.models import User as PrismaUser
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.data.db_accessors import user_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.cache import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,6 +25,40 @@ logger = logging.getLogger(__name__)
|
||||
_USAGE_KEY_PREFIX = "copilot:usage"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subscription tier definitions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SubscriptionTier(str, Enum):
|
||||
"""Subscription tiers with increasing token allowances.
|
||||
|
||||
Mirrors the ``SubscriptionTier`` enum in ``schema.prisma``.
|
||||
Once ``prisma generate`` is run, this can be replaced with::
|
||||
|
||||
from prisma.enums import SubscriptionTier
|
||||
"""
|
||||
|
||||
FREE = "FREE"
|
||||
PRO = "PRO"
|
||||
BUSINESS = "BUSINESS"
|
||||
ENTERPRISE = "ENTERPRISE"
|
||||
|
||||
|
||||
# Multiplier applied to the base limits (from LD / config) for each tier.
|
||||
# Intentionally int (not float): keeps limits as whole token counts and avoids
|
||||
# floating-point rounding. If fractional multipliers are ever needed, change
|
||||
# the type and round the result in get_global_rate_limits().
|
||||
TIER_MULTIPLIERS: dict[SubscriptionTier, int] = {
|
||||
SubscriptionTier.FREE: 1,
|
||||
SubscriptionTier.PRO: 5,
|
||||
SubscriptionTier.BUSINESS: 20,
|
||||
SubscriptionTier.ENTERPRISE: 60,
|
||||
}
|
||||
|
||||
DEFAULT_TIER = SubscriptionTier.FREE
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
"""Usage within a single time window."""
|
||||
|
||||
@@ -36,6 +74,7 @@ class CoPilotUsageStatus(BaseModel):
|
||||
|
||||
daily: UsageWindow
|
||||
weekly: UsageWindow
|
||||
tier: SubscriptionTier = DEFAULT_TIER
|
||||
reset_cost: int = Field(
|
||||
default=0,
|
||||
description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
|
||||
@@ -66,6 +105,7 @@ async def get_usage_status(
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
rate_limit_reset_cost: int = 0,
|
||||
tier: SubscriptionTier = DEFAULT_TIER,
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get current usage status for a user.
|
||||
|
||||
@@ -74,6 +114,7 @@ async def get_usage_status(
|
||||
daily_token_limit: Max tokens per day (0 = unlimited).
|
||||
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||
rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
|
||||
tier: The user's rate-limit tier (included in the response).
|
||||
|
||||
Returns:
|
||||
CoPilotUsageStatus with current usage and limits.
|
||||
@@ -103,6 +144,7 @@ async def get_usage_status(
|
||||
limit=weekly_token_limit,
|
||||
resets_at=_weekly_reset_time(now=now),
|
||||
),
|
||||
tier=tier,
|
||||
reset_cost=rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
@@ -161,8 +203,9 @@ async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
|
||||
daily_token_limit: The configured daily token limit. When positive,
|
||||
the weekly counter is reduced by this amount.
|
||||
|
||||
Fails open: returns False if Redis is unavailable (consistent with
|
||||
the fail-open design of this module).
|
||||
Returns False if Redis is unavailable so the caller can handle
|
||||
compensation (fail-closed for billed operations, unlike the read-only
|
||||
rate-limit checks which fail-open).
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
@@ -342,20 +385,103 @@ async def record_token_usage(
|
||||
)
|
||||
|
||||
|
||||
class _UserNotFoundError(Exception):
|
||||
"""Raised when a user record is missing or has no subscription tier.
|
||||
|
||||
Used internally by ``_fetch_user_tier`` to signal a cache-miss condition:
|
||||
by raising instead of returning ``DEFAULT_TIER``, we prevent the ``@cached``
|
||||
decorator from storing the fallback value. This avoids a race condition
|
||||
where a non-existent user's DEFAULT_TIER is cached, then the user is
|
||||
created with a higher tier but receives the stale cached FREE tier for
|
||||
up to 5 minutes.
|
||||
"""
|
||||
|
||||
|
||||
@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
|
||||
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
|
||||
"""Fetch the user's rate-limit tier from the database (cached via Redis).
|
||||
|
||||
Uses ``shared_cache=True`` so that tier changes propagate across all pods
|
||||
immediately when the cache entry is invalidated (via ``cache_delete``).
|
||||
|
||||
Only successful DB lookups of existing users with a valid tier are cached.
|
||||
Raises ``_UserNotFoundError`` when the user is missing or has no tier, so
|
||||
the ``@cached`` decorator does **not** store a fallback value. This
|
||||
prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
|
||||
cached and then persists after the user is created with a higher tier.
|
||||
"""
|
||||
try:
|
||||
user = await user_db().get_user_by_id(user_id)
|
||||
except Exception:
|
||||
raise _UserNotFoundError(user_id)
|
||||
if user.subscription_tier:
|
||||
return SubscriptionTier(user.subscription_tier)
|
||||
raise _UserNotFoundError(user_id)
|
||||
|
||||
|
||||
async def get_user_tier(user_id: str) -> SubscriptionTier:
|
||||
"""Look up the user's rate-limit tier from the database.
|
||||
|
||||
Successful results are cached for 5 minutes (via ``_fetch_user_tier``)
|
||||
to avoid a DB round-trip on every rate-limit check.
|
||||
|
||||
Falls back to ``DEFAULT_TIER`` **without caching** when the DB is
|
||||
unreachable or returns an unrecognised value, so the next call retries
|
||||
the query instead of serving a stale fallback for up to 5 minutes.
|
||||
"""
|
||||
try:
|
||||
return await _fetch_user_tier(user_id)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to resolve rate-limit tier for user %s, defaulting to %s: %s",
|
||||
user_id[:8],
|
||||
DEFAULT_TIER.value,
|
||||
exc,
|
||||
)
|
||||
return DEFAULT_TIER
|
||||
|
||||
|
||||
# Expose cache management on the public function so callers (including tests)
|
||||
# never need to reach into the private ``_fetch_user_tier``.
|
||||
get_user_tier.cache_clear = _fetch_user_tier.cache_clear # type: ignore[attr-defined]
|
||||
get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
|
||||
"""Persist the user's rate-limit tier to the database.
|
||||
|
||||
Also invalidates the ``get_user_tier`` cache for this user so that
|
||||
subsequent rate-limit checks immediately see the new tier.
|
||||
|
||||
Raises:
|
||||
prisma.errors.RecordNotFoundError: If the user does not exist.
|
||||
"""
|
||||
await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"subscriptionTier": tier.value},
|
||||
)
|
||||
# Invalidate cached tier so rate-limit checks pick up the change immediately.
|
||||
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def get_global_rate_limits(
|
||||
user_id: str,
|
||||
config_daily: int,
|
||||
config_weekly: int,
|
||||
) -> tuple[int, int]:
|
||||
) -> tuple[int, int, SubscriptionTier]:
|
||||
"""Resolve global rate limits from LaunchDarkly, falling back to config.
|
||||
|
||||
The base limits (from LD or config) are multiplied by the user's
|
||||
tier multiplier so that higher tiers receive proportionally larger
|
||||
allowances.
|
||||
|
||||
Args:
|
||||
user_id: User ID for LD flag evaluation context.
|
||||
config_daily: Fallback daily limit from ChatConfig.
|
||||
config_weekly: Fallback weekly limit from ChatConfig.
|
||||
|
||||
Returns:
|
||||
(daily_token_limit, weekly_token_limit) tuple.
|
||||
(daily_token_limit, weekly_token_limit, tier) 3-tuple.
|
||||
"""
|
||||
# Lazy import to avoid circular dependency:
|
||||
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
|
||||
@@ -377,7 +503,15 @@ async def get_global_rate_limits(
|
||||
except (TypeError, ValueError):
|
||||
logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
|
||||
weekly = config_weekly
|
||||
return daily, weekly
|
||||
|
||||
# Apply tier multiplier
|
||||
tier = await get_user_tier(user_id)
|
||||
multiplier = TIER_MULTIPLIERS.get(tier, 1)
|
||||
if multiplier != 1:
|
||||
daily = daily * multiplier
|
||||
weekly = weekly * multiplier
|
||||
|
||||
return daily, weekly, tier
|
||||
|
||||
|
||||
async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,7 +9,7 @@ import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from backend.api.features.chat.routes import reset_copilot_usage
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
|
||||
@@ -53,6 +53,18 @@ def _mock_settings(enable_credit: bool = True):
|
||||
return mock
|
||||
|
||||
|
||||
def _mock_rate_limits(
|
||||
daily: int = 2_500_000,
|
||||
weekly: int = 12_500_000,
|
||||
tier: SubscriptionTier = SubscriptionTier.PRO,
|
||||
):
|
||||
"""Mock get_global_rate_limits to return fixed limits (no tier multiplier)."""
|
||||
return patch(
|
||||
f"{_MODULE}.get_global_rate_limits",
|
||||
AsyncMock(return_value=(daily, weekly, tier)),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestResetCopilotUsage:
|
||||
async def test_feature_disabled_returns_400(self):
|
||||
@@ -70,6 +82,7 @@ class TestResetCopilotUsage:
|
||||
with (
|
||||
patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
_mock_rate_limits(daily=0),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await reset_copilot_usage(user_id="user-1")
|
||||
@@ -83,6 +96,7 @@ class TestResetCopilotUsage:
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
_mock_rate_limits(),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
|
||||
@@ -112,6 +126,7 @@ class TestResetCopilotUsage:
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
_mock_rate_limits(),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
|
||||
@@ -141,6 +156,7 @@ class TestResetCopilotUsage:
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
_mock_rate_limits(),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
|
||||
@@ -171,6 +187,7 @@ class TestResetCopilotUsage:
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
_mock_rate_limits(),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -208,6 +225,7 @@ class TestResetCopilotUsage:
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
_mock_rate_limits(),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
|
||||
@@ -228,6 +246,7 @@ class TestResetCopilotUsage:
|
||||
with (
|
||||
patch(f"{_MODULE}.config", _make_config()),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
_mock_rate_limits(),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -245,6 +264,7 @@ class TestResetCopilotUsage:
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
_mock_rate_limits(),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
|
||||
@@ -275,6 +295,7 @@ class TestResetCopilotUsage:
|
||||
with (
|
||||
patch(f"{_MODULE}.config", cfg),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
_mock_rate_limits(),
|
||||
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
|
||||
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
|
||||
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
|
||||
|
||||
@@ -3,41 +3,62 @@
|
||||
You can create, edit, and customize agents directly. You ARE the brain —
|
||||
generate the agent JSON yourself using block schemas, then validate and save.
|
||||
|
||||
### Clarifying Before Building
|
||||
### Clarifying — Before or During Building
|
||||
|
||||
Before starting the workflow below, check whether the user's goal is
|
||||
**ambiguous** — missing the output format, delivery channel, data source,
|
||||
or trigger. If so:
|
||||
1. Call `find_block` with a query targeting the ambiguous dimension to
|
||||
discover what the platform actually supports.
|
||||
2. Ask the user **one concrete question** grounded in the discovered
|
||||
Use `ask_question` whenever the user's intent is ambiguous — whether
|
||||
that's before starting or midway through the workflow. Common moments:
|
||||
|
||||
- **Before building**: output format, delivery channel, data source, or
|
||||
trigger is unspecified.
|
||||
- **During block discovery**: multiple blocks could fit and the user
|
||||
should choose.
|
||||
- **During JSON generation**: a wiring decision depends on user
|
||||
preference.
|
||||
|
||||
Steps:
|
||||
1. Call `find_block` (or another discovery tool) to learn what the
|
||||
platform actually supports for the ambiguous dimension.
|
||||
2. Call `ask_question` with a concrete question listing the discovered
|
||||
options (e.g. "The platform supports Gmail, Slack, and Google Docs —
|
||||
which should the agent use for delivery?").
|
||||
3. **Wait for the user's answer** before proceeding.
|
||||
3. **Wait for the user's answer** before continuing.
|
||||
|
||||
**Skip this** when the goal already specifies all dimensions (e.g.
|
||||
"scrape prices from Amazon and email me daily").
|
||||
|
||||
### Workflow for Creating/Editing Agents
|
||||
|
||||
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
|
||||
1. **If editing**: First narrow to the specific agent by UUID, then fetch its
|
||||
graph: `find_library_agent(query="<agent_id>", include_graph=true)`. This
|
||||
returns the full graph structure (nodes + links). **Never edit blindly** —
|
||||
always inspect the current graph first so you know exactly what to change.
|
||||
Avoid using `include_graph=true` with broad keyword searches, as fetching
|
||||
multiple graphs at once is expensive and consumes LLM context budget.
|
||||
2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
|
||||
search for relevant blocks. This returns block IDs, names, descriptions,
|
||||
and full input/output schemas.
|
||||
2. **Find library agents**: Call `find_library_agent` to discover reusable
|
||||
3. **Find library agents**: Call `find_library_agent` to discover reusable
|
||||
agents that can be composed as sub-agents via `AgentExecutorBlock`.
|
||||
3. **Generate JSON**: Build the agent JSON using block schemas:
|
||||
- Use block IDs from step 1 as `block_id` in nodes
|
||||
4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas:
|
||||
- Use block IDs from step 2 as `block_id` in nodes
|
||||
- Wire outputs to inputs using links
|
||||
- Set design-time config in `input_default`
|
||||
- Use `AgentInputBlock` for values the user provides at runtime
|
||||
4. **Write to workspace**: Save the JSON to a workspace file so the user
|
||||
- When editing, apply targeted changes and preserve unchanged parts
|
||||
5. **Write to workspace**: Save the JSON to a workspace file so the user
|
||||
can review it: `write_workspace_file(filename="agent.json", content=...)`
|
||||
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
|
||||
6. **Validate**: Call `validate_agent_graph` with the agent JSON to check
|
||||
for errors
|
||||
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
|
||||
7. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
|
||||
or fix manually based on the error descriptions. Iterate until valid.
|
||||
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
|
||||
8. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
|
||||
the final `agent_json`
|
||||
8. **Dry-run**: ALWAYS call `run_agent` with `dry_run=True` and
|
||||
`wait_for_result=120` to verify the agent works end-to-end.
|
||||
9. **Inspect & fix**: Check the dry-run output for errors. If issues are
|
||||
found, call `edit_agent` to fix and dry-run again. Repeat until the
|
||||
simulation passes or the problems are clearly unfixable.
|
||||
See "REQUIRED: Dry-Run Verification Loop" section below for details.
|
||||
|
||||
### Agent JSON Structure
|
||||
|
||||
@@ -89,8 +110,8 @@ These define the agent's interface — what it accepts and what it produces.
|
||||
|
||||
**AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
|
||||
- Specialized input block that presents a dropdown/select to the user
|
||||
- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
|
||||
- Optional: `title`, `description`, `value` (default selection)
|
||||
- Required `input_default` fields: `name` (str)
|
||||
- Optional: `options` (list of dropdown values; when omitted/empty, input behaves as free-text), `title`, `description`, `value` (default selection)
|
||||
- Output: `result` — the user-selected value at runtime
|
||||
- Use this instead of AgentInputBlock when the user should pick from a fixed set of options
|
||||
|
||||
@@ -231,19 +252,62 @@ call in a loop until the task is complete:
|
||||
Regular blocks work exactly like sub-agents as tools — wire each input
|
||||
field from `source_name: "tools"` on the Orchestrator side.
|
||||
|
||||
### Testing with Dry Run
|
||||
### REQUIRED: Dry-Run Verification Loop (create -> dry-run -> fix)
|
||||
|
||||
After saving an agent, suggest a dry run to validate wiring without consuming
|
||||
real API calls, credentials, or credits:
|
||||
After creating or editing an agent, you MUST dry-run it before telling the
|
||||
user the agent is ready. NEVER skip this step.
|
||||
|
||||
1. **Run**: Call `run_agent` or `run_block` with `dry_run=True` and provide
|
||||
sample inputs. This executes the graph with mock outputs, verifying that
|
||||
links resolve correctly and required inputs are satisfied.
|
||||
2. **Check results**: Call `view_agent_output` with `show_execution_details=True`
|
||||
to inspect the full node-by-node execution trace. This shows what each node
|
||||
received as input and produced as output, making it easy to spot wiring issues.
|
||||
3. **Iterate**: If the dry run reveals wiring issues or missing inputs, fix
|
||||
the agent JSON and re-save before suggesting a real execution.
|
||||
#### Step-by-step workflow
|
||||
|
||||
1. **Create/Edit**: Call `create_agent` or `edit_agent` to save the agent.
|
||||
2. **Dry-run**: Call `run_agent` with `dry_run=True`, `wait_for_result=120`,
|
||||
and realistic sample inputs that exercise every path in the agent. This
|
||||
simulates execution using an LLM for each block — no real API calls,
|
||||
credentials, or credits are consumed.
|
||||
3. **Inspect output**: Examine the dry-run result for problems. If
|
||||
`wait_for_result` returns only a summary, call
|
||||
`view_agent_output(execution_id=..., show_execution_details=True)` to
|
||||
see the full node-by-node execution trace. Look for:
|
||||
- **Errors / failed nodes** — a node raised an exception or returned an
|
||||
error status. Common causes: wrong `source_name`/`sink_name` in links,
|
||||
missing `input_default` values, or referencing a nonexistent block output.
|
||||
- **Null / empty outputs** — data did not flow through a link. Verify that
|
||||
`source_name` and `sink_name` match the block schemas exactly (case-
|
||||
sensitive, including nested `_#_` notation).
|
||||
- **Nodes that never executed** — the node was not reached. Likely a
|
||||
missing or broken link from an upstream node.
|
||||
- **Unexpected values** — data arrived but in the wrong type or
|
||||
structure. Check type compatibility between linked ports.
|
||||
4. **Fix**: If any issues are found, call `edit_agent` with the corrected
|
||||
agent JSON, then go back to step 2.
|
||||
5. **Repeat**: Continue the dry-run -> fix cycle until the simulation passes
|
||||
or the problems are clearly unfixable. If you stop making progress,
|
||||
report the remaining issues to the user and ask for guidance.
|
||||
|
||||
#### Good vs bad dry-run output
|
||||
|
||||
**Good output** (agent is ready):
|
||||
- All nodes executed successfully (no errors in the execution trace)
|
||||
- Data flows through every link with non-null, correctly-typed values
|
||||
- The final `AgentOutputBlock` contains a meaningful result
|
||||
- Status is `COMPLETED`
|
||||
|
||||
**Bad output** (needs fixing):
|
||||
- Status is `FAILED` — check the error message for the failing node
|
||||
- An output node received `null` — trace back to find the broken link
|
||||
- A node received data in the wrong format (e.g. string where list expected)
|
||||
- Nodes downstream of a failing node were skipped entirely
|
||||
|
||||
**Special block behaviour in dry-run mode:**
|
||||
- **OrchestratorBlock** and **AgentExecutorBlock** execute for real so the
|
||||
orchestrator can make LLM calls and agent executors can spawn child graphs.
|
||||
Their downstream tool blocks and child-graph blocks are still simulated.
|
||||
Note: real LLM inference calls are made (consuming API quota), even though
|
||||
platform credits are not charged. Agent-mode iterations are capped at 1 in
|
||||
dry-run to keep it fast.
|
||||
- **MCPToolBlock** is simulated using the selected tool's name and JSON Schema
|
||||
so the LLM can produce a realistic mock response without connecting to the
|
||||
MCP server.
|
||||
|
||||
### Example: Simple AI Text Processor
|
||||
|
||||
|
||||
@@ -2,14 +2,30 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from backend.util import json
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session", name="server")
|
||||
async def _server_noop() -> None:
|
||||
"""No-op server stub — SDK tests don't need the full backend."""
|
||||
return None
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(
|
||||
scope="session", loop_scope="session", autouse=True, name="graph_cleanup"
|
||||
)
|
||||
async def _graph_cleanup_noop() -> AsyncIterator[None]:
|
||||
"""No-op graph cleanup stub."""
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_chat_config():
|
||||
"""Mock ChatConfig so compact_transcript tests skip real config lookup."""
|
||||
|
||||
@@ -8,6 +8,9 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
|
||||
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
@@ -28,6 +31,12 @@ from backend.copilot.context import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default number of lines returned by ``read_file`` when the caller does not
|
||||
# specify a limit. Also used as the threshold in ``bridge_to_sandbox`` to
|
||||
# decide whether the model is requesting the full file (and thus whether the
|
||||
# bridge copy is worthwhile).
|
||||
_DEFAULT_READ_LIMIT = 2000
|
||||
|
||||
|
||||
async def _check_sandbox_symlink_escape(
|
||||
sandbox: Any,
|
||||
@@ -89,7 +98,7 @@ def _get_sandbox_and_path(
|
||||
return sandbox, remote
|
||||
|
||||
|
||||
async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
|
||||
async def _sandbox_write(sandbox: Any, path: str, content: str | bytes) -> None:
|
||||
"""Write *content* to *path* inside the sandbox.
|
||||
|
||||
The E2B filesystem API (``sandbox.files.write``) and the command API
|
||||
@@ -102,11 +111,14 @@ async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
|
||||
To work around this, writes targeting ``/tmp`` are performed via
|
||||
``tee`` through the command API, which runs as the sandbox ``user``
|
||||
and can therefore always overwrite user-owned files.
|
||||
|
||||
*content* may be ``str`` (text) or ``bytes`` (binary). Both paths
|
||||
are handled correctly: text is encoded to bytes for the base64 shell
|
||||
pipe, and raw bytes are passed through without any encoding.
|
||||
"""
|
||||
if path == "/tmp" or path.startswith("/tmp/"):
|
||||
import base64 as _b64
|
||||
|
||||
encoded = _b64.b64encode(content.encode()).decode()
|
||||
raw = content.encode() if isinstance(content, str) else content
|
||||
encoded = base64.b64encode(raw).decode()
|
||||
result = await sandbox.commands.run(
|
||||
f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(path)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
@@ -128,14 +140,25 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
offset: int = max(0, int(args.get("offset", 0)))
|
||||
limit: int = max(1, int(args.get("limit", 2000)))
|
||||
limit: int = max(1, int(args.get("limit", _DEFAULT_READ_LIMIT)))
|
||||
|
||||
if not file_path:
|
||||
return _mcp("file_path is required", error=True)
|
||||
|
||||
# SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
|
||||
# SDK-internal paths (tool-results/tool-outputs, ephemeral working dir)
|
||||
# stay on the host. When E2B is active, also copy the file into the
|
||||
# sandbox so bash_exec can access it for further processing.
|
||||
if _is_allowed_local(file_path):
|
||||
return _read_local(file_path, offset, limit)
|
||||
result = _read_local(file_path, offset, limit)
|
||||
if not result.get("isError"):
|
||||
sandbox = _get_sandbox()
|
||||
if sandbox is not None:
|
||||
annotation = await bridge_and_annotate(
|
||||
sandbox, file_path, offset, limit
|
||||
)
|
||||
if annotation:
|
||||
result["content"][0]["text"] += annotation
|
||||
return result
|
||||
|
||||
result = _get_sandbox_and_path(file_path)
|
||||
if isinstance(result, dict):
|
||||
@@ -302,6 +325,103 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return _mcp(output if output else "No matches found.")
|
||||
|
||||
|
||||
# Bridging: copy SDK-internal files into E2B sandbox
|
||||
|
||||
# Files larger than this are written to /home/user/ via sandbox.files.write()
|
||||
# instead of /tmp/ via shell base64, to avoid shell argument length limits
|
||||
# and E2B command timeouts. Base64 expands content by ~33%, so keep this
|
||||
# well under the typical Linux ARG_MAX (128 KB).
|
||||
_BRIDGE_SHELL_MAX_BYTES = 32 * 1024 # 32 KB
|
||||
# Files larger than this are skipped entirely to avoid excessive transfer times.
|
||||
_BRIDGE_SKIP_BYTES = 50 * 1024 * 1024 # 50 MB
|
||||
|
||||
|
||||
async def bridge_to_sandbox(
|
||||
sandbox: Any, file_path: str, offset: int, limit: int
|
||||
) -> str | None:
|
||||
"""Best-effort copy of a host-side SDK file into the E2B sandbox.
|
||||
|
||||
When the model reads an SDK-internal file (e.g. tool-results), it often
|
||||
wants to process the data with bash. Copying the file into the sandbox
|
||||
under a stable name lets ``bash_exec`` access it without extra steps.
|
||||
|
||||
Only copies when offset=0 and limit is large enough to indicate the model
|
||||
wants the full file. Errors are logged but never propagated.
|
||||
|
||||
Returns the sandbox path on success, or ``None`` on skip/failure.
|
||||
|
||||
Size handling:
|
||||
- <= 32 KB: written to ``/tmp/<hash>-<basename>`` via shell base64
|
||||
(``_sandbox_write``). Kept small to stay within ARG_MAX.
|
||||
- 32 KB - 50 MB: written to ``/home/user/<hash>-<basename>`` via
|
||||
``sandbox.files.write()`` to avoid shell argument length limits.
|
||||
- > 50 MB: skipped entirely with a warning.
|
||||
|
||||
The sandbox filename is prefixed with a short hash of the full source
|
||||
path to avoid collisions when different source files share the same
|
||||
basename (e.g. multiple ``result.json`` files).
|
||||
"""
|
||||
if offset != 0 or limit < _DEFAULT_READ_LIMIT:
|
||||
return None
|
||||
try:
|
||||
expanded = os.path.realpath(os.path.expanduser(file_path))
|
||||
basename = os.path.basename(expanded)
|
||||
source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
|
||||
unique_name = f"{source_id}-{basename}"
|
||||
file_size = os.path.getsize(expanded)
|
||||
if file_size > _BRIDGE_SKIP_BYTES:
|
||||
logger.warning(
|
||||
"[E2B] Skipping bridge for large file (%d bytes): %s",
|
||||
file_size,
|
||||
basename,
|
||||
)
|
||||
return None
|
||||
|
||||
def _read_bytes() -> bytes:
|
||||
with open(expanded, "rb") as fh:
|
||||
return fh.read()
|
||||
|
||||
raw_content = await asyncio.to_thread(_read_bytes)
|
||||
try:
|
||||
text_content: str | None = raw_content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
text_content = None
|
||||
data: str | bytes = text_content if text_content is not None else raw_content
|
||||
if file_size <= _BRIDGE_SHELL_MAX_BYTES:
|
||||
sandbox_path = f"/tmp/{unique_name}"
|
||||
await _sandbox_write(sandbox, sandbox_path, data)
|
||||
else:
|
||||
sandbox_path = f"/home/user/{unique_name}"
|
||||
await sandbox.files.write(sandbox_path, data)
|
||||
logger.info(
|
||||
"[E2B] Bridged SDK file to sandbox: %s -> %s", basename, sandbox_path
|
||||
)
|
||||
return sandbox_path
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[E2B] Failed to bridge SDK file to sandbox: %s",
|
||||
file_path,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def bridge_and_annotate(
|
||||
sandbox: Any, file_path: str, offset: int, limit: int
|
||||
) -> str | None:
|
||||
"""Bridge a host file to the sandbox and return a newline-prefixed annotation.
|
||||
|
||||
Combines ``bridge_to_sandbox`` with the standard annotation suffix so
|
||||
callers don't need to duplicate the pattern. Returns a string like
|
||||
``"\\n[Sandbox copy available at /tmp/abc-file.txt]"`` on success, or
|
||||
``None`` if bridging was skipped or failed.
|
||||
"""
|
||||
sandbox_path = await bridge_to_sandbox(sandbox, file_path, offset, limit)
|
||||
if sandbox_path is None:
|
||||
return None
|
||||
return f"\n[Sandbox copy available at {sandbox_path}]"
|
||||
|
||||
|
||||
# Local read (for SDK-internal paths)
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
Pure unit tests with no external dependencies (no E2B, no sandbox).
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
from types import SimpleNamespace
|
||||
@@ -13,12 +14,26 @@ import pytest
|
||||
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
|
||||
|
||||
from .e2b_file_tools import (
|
||||
_BRIDGE_SHELL_MAX_BYTES,
|
||||
_BRIDGE_SKIP_BYTES,
|
||||
_DEFAULT_READ_LIMIT,
|
||||
_check_sandbox_symlink_escape,
|
||||
_read_local,
|
||||
_sandbox_write,
|
||||
bridge_and_annotate,
|
||||
bridge_to_sandbox,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
|
||||
def _expected_bridge_path(file_path: str, prefix: str = "/tmp") -> str:
|
||||
"""Compute the expected sandbox path for a bridged file."""
|
||||
expanded = os.path.realpath(os.path.expanduser(file_path))
|
||||
basename = os.path.basename(expanded)
|
||||
source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
|
||||
return f"{prefix}/{source_id}-{basename}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -91,9 +106,9 @@ class TestResolveSandboxPath:
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_local — host filesystem reads with allowlist enforcement
|
||||
#
|
||||
# In E2B mode, _read_local only allows tool-results paths (via
|
||||
# is_allowed_local_path without sdk_cwd). Regular files live on the
|
||||
# sandbox, not the host.
|
||||
# In E2B mode, _read_local only allows tool-results/tool-outputs paths
|
||||
# (via is_allowed_local_path without sdk_cwd). Regular files live on
|
||||
# the sandbox, not the host.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -119,7 +134,7 @@ class TestReadLocal:
|
||||
)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local(filepath, offset=0, limit=2000)
|
||||
result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
|
||||
assert result["isError"] is False
|
||||
assert "line 1" in result["content"][0]["text"]
|
||||
assert "line 2" in result["content"][0]["text"]
|
||||
@@ -127,6 +142,25 @@ class TestReadLocal:
|
||||
_current_project_dir.reset(token)
|
||||
os.unlink(filepath)
|
||||
|
||||
def test_read_tool_outputs_file(self):
|
||||
"""Reading a tool-outputs file should also succeed."""
|
||||
encoded = "-tmp-copilot-e2b-test-read-outputs"
|
||||
tool_outputs_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-outputs"
|
||||
)
|
||||
os.makedirs(tool_outputs_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_outputs_dir, "sdk-abc123.json")
|
||||
with open(filepath, "w") as f:
|
||||
f.write('{"data": "test"}\n')
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
|
||||
assert result["isError"] is False
|
||||
assert "test" in result["content"][0]["text"]
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
|
||||
def test_read_disallowed_path_blocked(self):
|
||||
"""Reading /etc/passwd should be blocked by the allowlist."""
|
||||
result = _read_local("/etc/passwd", offset=0, limit=10)
|
||||
@@ -335,3 +369,199 @@ class TestSandboxWrite:
|
||||
encoded_in_cmd = call_args.split("echo ")[1].split(" |")[0].strip("'")
|
||||
decoded = base64.b64decode(encoded_in_cmd).decode()
|
||||
assert decoded == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bridge_to_sandbox — copy SDK-internal files into E2B sandbox
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bridge_sandbox() -> SimpleNamespace:
|
||||
"""Build a sandbox mock suitable for bridge_to_sandbox tests."""
|
||||
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
files = SimpleNamespace(write=AsyncMock())
|
||||
return SimpleNamespace(commands=commands, files=files)
|
||||
|
||||
|
||||
class TestBridgeToSandbox:
|
||||
@pytest.mark.asyncio
|
||||
async def test_happy_path_small_file(self, tmp_path):
|
||||
"""A small file is bridged to /tmp/<hash>-<basename> via _sandbox_write."""
|
||||
f = tmp_path / "result.json"
|
||||
f.write_text('{"ok": true}')
|
||||
sandbox = _make_bridge_sandbox()
|
||||
|
||||
result = await bridge_to_sandbox(
|
||||
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
|
||||
)
|
||||
|
||||
expected = _expected_bridge_path(str(f))
|
||||
assert result == expected
|
||||
sandbox.commands.run.assert_called_once()
|
||||
cmd = sandbox.commands.run.call_args[0][0]
|
||||
assert "result.json" in cmd
|
||||
sandbox.files.write.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_when_offset_nonzero(self, tmp_path):
|
||||
"""Bridging is skipped when offset != 0 (partial read)."""
|
||||
f = tmp_path / "data.txt"
|
||||
f.write_text("content")
|
||||
sandbox = _make_bridge_sandbox()
|
||||
|
||||
result = await bridge_to_sandbox(
|
||||
sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
|
||||
)
|
||||
|
||||
assert result is None
|
||||
sandbox.commands.run.assert_not_called()
|
||||
sandbox.files.write.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_when_limit_too_small(self, tmp_path):
|
||||
"""Bridging is skipped when limit < _DEFAULT_READ_LIMIT (partial read)."""
|
||||
f = tmp_path / "data.txt"
|
||||
f.write_text("content")
|
||||
sandbox = _make_bridge_sandbox()
|
||||
|
||||
await bridge_to_sandbox(sandbox, str(f), offset=0, limit=100)
|
||||
|
||||
sandbox.commands.run.assert_not_called()
|
||||
sandbox.files.write.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonexistent_file_does_not_raise(self, tmp_path):
|
||||
"""Bridging a non-existent file logs but does not propagate errors."""
|
||||
sandbox = _make_bridge_sandbox()
|
||||
|
||||
await bridge_to_sandbox(
|
||||
sandbox, str(tmp_path / "ghost.txt"), offset=0, limit=_DEFAULT_READ_LIMIT
|
||||
)
|
||||
|
||||
sandbox.commands.run.assert_not_called()
|
||||
sandbox.files.write.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_write_failure_returns_none(self, tmp_path):
|
||||
"""If sandbox write fails, returns None (best-effort)."""
|
||||
f = tmp_path / "data.txt"
|
||||
f.write_text("content")
|
||||
sandbox = _make_bridge_sandbox()
|
||||
sandbox.commands.run.side_effect = RuntimeError("E2B timeout")
|
||||
|
||||
result = await bridge_to_sandbox(
|
||||
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_file_uses_files_api(self, tmp_path):
|
||||
"""Files > 32 KB but <= 50 MB are written to /home/user/ via files.write."""
|
||||
f = tmp_path / "big.json"
|
||||
f.write_bytes(b"x" * (_BRIDGE_SHELL_MAX_BYTES + 1))
|
||||
sandbox = _make_bridge_sandbox()
|
||||
|
||||
result = await bridge_to_sandbox(
|
||||
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
|
||||
)
|
||||
|
||||
expected = _expected_bridge_path(str(f), prefix="/home/user")
|
||||
assert result == expected
|
||||
sandbox.files.write.assert_called_once()
|
||||
call_args = sandbox.files.write.call_args[0]
|
||||
assert call_args[0] == expected
|
||||
sandbox.commands.run.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_small_binary_file_preserves_bytes(self, tmp_path):
|
||||
"""A small binary file is bridged to /tmp via base64 without corruption."""
|
||||
binary_data = bytes(range(256))
|
||||
f = tmp_path / "image.png"
|
||||
f.write_bytes(binary_data)
|
||||
sandbox = _make_bridge_sandbox()
|
||||
|
||||
result = await bridge_to_sandbox(
|
||||
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
|
||||
)
|
||||
|
||||
expected = _expected_bridge_path(str(f))
|
||||
assert result == expected
|
||||
sandbox.commands.run.assert_called_once()
|
||||
cmd = sandbox.commands.run.call_args[0][0]
|
||||
assert "base64" in cmd
|
||||
sandbox.files.write.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_binary_file_writes_raw_bytes(self, tmp_path):
|
||||
"""A large binary file is bridged to /home/user/ as raw bytes."""
|
||||
binary_data = bytes(range(256)) * 200
|
||||
f = tmp_path / "photo.jpg"
|
||||
f.write_bytes(binary_data)
|
||||
sandbox = _make_bridge_sandbox()
|
||||
|
||||
result = await bridge_to_sandbox(
|
||||
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
|
||||
)
|
||||
|
||||
expected = _expected_bridge_path(str(f), prefix="/home/user")
|
||||
assert result == expected
|
||||
sandbox.files.write.assert_called_once()
|
||||
call_args = sandbox.files.write.call_args[0]
|
||||
assert call_args[0] == expected
|
||||
assert call_args[1] == binary_data
|
||||
sandbox.commands.run.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_very_large_file_skipped(self, tmp_path):
|
||||
"""Files > 50 MB are skipped entirely."""
|
||||
f = tmp_path / "huge.bin"
|
||||
# Create a sparse file to avoid actually writing 50 MB
|
||||
with open(f, "wb") as fh:
|
||||
fh.seek(_BRIDGE_SKIP_BYTES + 1)
|
||||
fh.write(b"\0")
|
||||
sandbox = _make_bridge_sandbox()
|
||||
|
||||
result = await bridge_to_sandbox(
|
||||
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
sandbox.commands.run.assert_not_called()
|
||||
sandbox.files.write.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bridge_and_annotate — shared helper wrapping bridge_to_sandbox + annotation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBridgeAndAnnotate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_annotation_on_success(self, tmp_path):
|
||||
"""On success, returns a newline-prefixed annotation with the sandbox path."""
|
||||
f = tmp_path / "data.json"
|
||||
f.write_text('{"ok": true}')
|
||||
sandbox = _make_bridge_sandbox()
|
||||
|
||||
annotation = await bridge_and_annotate(
|
||||
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
|
||||
)
|
||||
|
||||
expected_path = _expected_bridge_path(str(f))
|
||||
assert annotation == f"\n[Sandbox copy available at {expected_path}]"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_skipped(self, tmp_path):
|
||||
"""When bridging is skipped (e.g. offset != 0), returns None."""
|
||||
f = tmp_path / "data.json"
|
||||
f.write_text("content")
|
||||
sandbox = _make_bridge_sandbox()
|
||||
|
||||
annotation = await bridge_and_annotate(
|
||||
sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
|
||||
)
|
||||
|
||||
assert annotation is None
|
||||
|
||||
@@ -20,6 +20,7 @@ config = ChatConfig()
|
||||
def build_sdk_env(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Build env vars for the SDK CLI subprocess.
|
||||
|
||||
@@ -29,25 +30,35 @@ def build_sdk_env(
|
||||
``ANTHROPIC_API_KEY`` from the parent environment.
|
||||
3. **OpenRouter** (default) — overrides base URL and auth token to
|
||||
route through the proxy, with Langfuse trace headers.
|
||||
|
||||
When *sdk_cwd* is provided, ``CLAUDE_CODE_TMPDIR`` is set so that
|
||||
the CLI writes temp/sub-agent output inside the per-session workspace
|
||||
directory rather than an inaccessible system temp path.
|
||||
"""
|
||||
# --- Mode 1: Claude Code subscription auth ---
|
||||
if config.use_claude_code_subscription:
|
||||
validate_subscription()
|
||||
return {
|
||||
env: dict[str, str] = {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
if sdk_cwd:
|
||||
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
|
||||
return env
|
||||
|
||||
# --- Mode 2: Direct Anthropic (no proxy hop) ---
|
||||
if not config.openrouter_active:
|
||||
return {}
|
||||
env = {}
|
||||
if sdk_cwd:
|
||||
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
|
||||
return env
|
||||
|
||||
# --- Mode 3: OpenRouter proxy ---
|
||||
base = (config.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
env: dict[str, str] = {
|
||||
env = {
|
||||
"ANTHROPIC_BASE_URL": base,
|
||||
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
|
||||
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
|
||||
@@ -65,4 +76,7 @@ def build_sdk_env(
|
||||
if parts:
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
|
||||
|
||||
if sdk_cwd:
|
||||
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
|
||||
|
||||
return env
|
||||
|
||||
@@ -240,3 +240,54 @@ class TestBuildSdkEnvModePriority:
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLAUDE_CODE_TMPDIR integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaudeCodeTmpdir:
|
||||
"""Verify build_sdk_env() sets CLAUDE_CODE_TMPDIR from *sdk_cwd*."""
|
||||
|
||||
def test_tmpdir_set_when_sdk_cwd_is_truthy(self):
|
||||
"""CLAUDE_CODE_TMPDIR is set to sdk_cwd when sdk_cwd is truthy."""
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(sdk_cwd="/tmp/copilot-workspace")
|
||||
|
||||
assert result["CLAUDE_CODE_TMPDIR"] == "/tmp/copilot-workspace"
|
||||
|
||||
def test_tmpdir_not_set_when_sdk_cwd_is_none(self):
|
||||
"""CLAUDE_CODE_TMPDIR is NOT in the env when sdk_cwd is None."""
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(sdk_cwd=None)
|
||||
|
||||
assert "CLAUDE_CODE_TMPDIR" not in result
|
||||
|
||||
def test_tmpdir_not_set_when_sdk_cwd_is_empty_string(self):
|
||||
"""CLAUDE_CODE_TMPDIR is NOT in the env when sdk_cwd is empty string."""
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(sdk_cwd="")
|
||||
|
||||
assert "CLAUDE_CODE_TMPDIR" not in result
|
||||
|
||||
@patch("backend.copilot.sdk.env.validate_subscription")
|
||||
def test_tmpdir_set_in_subscription_mode(self, mock_validate):
|
||||
"""CLAUDE_CODE_TMPDIR is set even in subscription mode."""
|
||||
cfg = _make_config(use_claude_code_subscription=True)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(sdk_cwd="/tmp/sub-workspace")
|
||||
|
||||
assert result["CLAUDE_CODE_TMPDIR"] == "/tmp/sub-workspace"
|
||||
assert result["ANTHROPIC_API_KEY"] == ""
|
||||
|
||||
@@ -28,13 +28,12 @@ Each result includes a `remotes` array with the exact server URL to use.
|
||||
|
||||
### Important: Check blocks first
|
||||
|
||||
Before using `run_mcp_tool`, always check if the platform already has blocks for the service
|
||||
using `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Google Docs,
|
||||
Google Calendar, Gmail, etc.) that work without MCP setup.
|
||||
Always follow the **Tool Discovery Priority** described in the tool notes:
|
||||
call `find_block` before resorting to `run_mcp_tool`.
|
||||
|
||||
Only use `run_mcp_tool` when:
|
||||
- The service is in the known hosted MCP servers list above, OR
|
||||
- You searched `find_block` first and found no matching blocks
|
||||
- You searched `find_block` first and found no matching blocks, AND
|
||||
- The service is in the known hosted MCP servers list above or found via the registry API
|
||||
|
||||
**Never guess or construct MCP server URLs.** Only use URLs from the known servers list above
|
||||
or from the `remotes[].url` field in MCP registry search results.
|
||||
|
||||
@@ -8,20 +8,19 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _friendly_error_text, _is_prompt_too_long
|
||||
from .transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_run_compression,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _friendly_error_text, _is_prompt_too_long
|
||||
from .transcript import compact_transcript, validate_transcript
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
@@ -403,7 +402,7 @@ class TestCompactTranscript:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -438,7 +437,7 @@ class TestCompactTranscript:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -462,7 +461,7 @@ class TestCompactTranscript:
|
||||
]
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LLM unavailable"),
|
||||
):
|
||||
@@ -568,11 +567,11 @@ class TestRunCompressionTimeout:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value="fake-client",
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
side_effect=_mock_compress,
|
||||
),
|
||||
):
|
||||
@@ -602,11 +601,11 @@ class TestRunCompressionTimeout:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
|
||||
@@ -29,6 +29,7 @@ from backend.copilot.response_model import (
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
from .compaction import compaction_events
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .tool_adapter import MCP_TOOL_PREFIX
|
||||
from .tool_adapter import _pending_tool_outputs as _pto
|
||||
@@ -689,3 +690,102 @@ def test_already_resolved_tool_skipped_in_user_message():
|
||||
assert (
|
||||
len(output_events) == 0
|
||||
), "Already-resolved tool should not emit duplicate output"
|
||||
|
||||
|
||||
# -- _end_text_if_open before compaction -------------------------------------
|
||||
|
||||
|
||||
def test_end_text_if_open_emits_text_end_before_finish_step():
|
||||
"""StreamTextEnd must be emitted before StreamFinishStep during compaction.
|
||||
|
||||
When ``emit_end_if_ready`` fires compaction events while a text block is
|
||||
still open, ``_end_text_if_open`` must close it first. If StreamFinishStep
|
||||
arrives before StreamTextEnd, the Vercel AI SDK clears ``activeTextParts``
|
||||
and raises "Received text-end for missing text part".
|
||||
"""
|
||||
adapter = _adapter()
|
||||
|
||||
# Open a text block by processing an AssistantMessage with text
|
||||
msg = AssistantMessage(content=[TextBlock(text="partial response")], model="test")
|
||||
adapter.convert_message(msg)
|
||||
assert adapter.has_started_text
|
||||
assert not adapter.has_ended_text
|
||||
|
||||
# Simulate what service.py does before yielding compaction events
|
||||
pre_close: list[StreamBaseResponse] = []
|
||||
adapter._end_text_if_open(pre_close)
|
||||
combined = pre_close + list(compaction_events("Compacted transcript"))
|
||||
|
||||
text_end_idx = next(
|
||||
(i for i, e in enumerate(combined) if isinstance(e, StreamTextEnd)), None
|
||||
)
|
||||
finish_step_idx = next(
|
||||
(i for i, e in enumerate(combined) if isinstance(e, StreamFinishStep)), None
|
||||
)
|
||||
|
||||
assert text_end_idx is not None, "StreamTextEnd must be present"
|
||||
assert finish_step_idx is not None, "StreamFinishStep must be present"
|
||||
assert text_end_idx < finish_step_idx, (
|
||||
f"StreamTextEnd (idx={text_end_idx}) must precede "
|
||||
f"StreamFinishStep (idx={finish_step_idx}) — otherwise the Vercel AI SDK "
|
||||
"clears activeTextParts before text-end arrives"
|
||||
)
|
||||
|
||||
|
||||
def test_step_open_must_reset_after_compaction_finish_step():
|
||||
"""Adapter step_open must be reset when compaction emits StreamFinishStep.
|
||||
|
||||
Compaction events bypass the adapter, so service.py must explicitly clear
|
||||
step_open after yielding a StreamFinishStep from compaction. Without this,
|
||||
the next AssistantMessage skips StreamStartStep because the adapter still
|
||||
thinks a step is open.
|
||||
"""
|
||||
adapter = _adapter()
|
||||
|
||||
# Open a step + text block via an AssistantMessage
|
||||
msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
|
||||
adapter.convert_message(msg)
|
||||
assert adapter.step_open is True
|
||||
|
||||
# Simulate what service.py does: close text, then check compaction events
|
||||
pre_close: list[StreamBaseResponse] = []
|
||||
adapter._end_text_if_open(pre_close)
|
||||
|
||||
events = list(compaction_events("Compacted transcript"))
|
||||
if any(isinstance(ev, StreamFinishStep) for ev in events):
|
||||
adapter.step_open = False
|
||||
|
||||
assert (
|
||||
adapter.step_open is False
|
||||
), "step_open must be False after compaction emits StreamFinishStep"
|
||||
|
||||
# Next AssistantMessage must open a new step
|
||||
msg2 = AssistantMessage(content=[TextBlock(text="continued")], model="test")
|
||||
results = adapter.convert_message(msg2)
|
||||
assert any(
|
||||
isinstance(r, StreamStartStep) for r in results
|
||||
), "A new StreamStartStep must be emitted after compaction closed the step"
|
||||
|
||||
|
||||
def test_end_text_if_open_no_op_when_no_text_open():
|
||||
"""_end_text_if_open emits nothing when no text block is open."""
|
||||
adapter = _adapter()
|
||||
results: list[StreamBaseResponse] = []
|
||||
adapter._end_text_if_open(results)
|
||||
assert results == []
|
||||
|
||||
|
||||
def test_end_text_if_open_no_op_after_text_already_ended():
|
||||
"""_end_text_if_open emits nothing when the text block is already closed."""
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
|
||||
adapter.convert_message(msg)
|
||||
# Close it once
|
||||
first: list[StreamBaseResponse] = []
|
||||
adapter._end_text_if_open(first)
|
||||
assert len(first) == 1
|
||||
assert isinstance(first[0], StreamTextEnd)
|
||||
# Second call must be a no-op
|
||||
second: list[StreamBaseResponse] = []
|
||||
adapter._end_text_if_open(second)
|
||||
assert second == []
|
||||
|
||||
@@ -26,18 +26,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
|
||||
from .transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
|
||||
from .transcript import compact_transcript, validate_transcript
|
||||
from .transcript_builder import TranscriptBuilder
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -113,7 +112,7 @@ class TestScenarioCompactAndRetry:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
@@ -170,7 +169,7 @@ class TestScenarioCompactFailsFallback:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LLM unavailable"),
|
||||
),
|
||||
@@ -261,7 +260,7 @@ class TestScenarioDoubleFailDBFallback:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
@@ -337,7 +336,7 @@ class TestScenarioCompactionIdentical:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
@@ -730,7 +729,7 @@ class TestRetryEdgeCases:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
@@ -841,7 +840,7 @@ class TestRetryStateReset:
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("boom"),
|
||||
),
|
||||
@@ -1010,7 +1009,7 @@ def _make_sdk_patches(
|
||||
(f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
|
||||
(f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
|
||||
(f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
|
||||
(f"{_SVC}.build_sdk_env", dict(return_value=None)),
|
||||
(f"{_SVC}.build_sdk_env", dict(return_value={})),
|
||||
(f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
|
||||
(f"{_SVC}.set_execution_context", {}),
|
||||
(
|
||||
@@ -1405,9 +1404,9 @@ class TestStreamChatCompletionRetryIntegration:
|
||||
events.append(event)
|
||||
|
||||
# Should NOT retry — only 1 attempt for auth errors
|
||||
assert attempt_count[0] == 1, (
|
||||
f"Expected 1 attempt (no retry for auth error), " f"got {attempt_count[0]}"
|
||||
)
|
||||
assert (
|
||||
attempt_count[0] == 1
|
||||
), f"Expected 1 attempt (no retry for auth error), got {attempt_count[0]}"
|
||||
errors = [e for e in events if isinstance(e, StreamError)]
|
||||
assert errors, "Expected StreamError"
|
||||
assert errors[0].code == "sdk_stream_error"
|
||||
@@ -1487,3 +1486,188 @@ class TestStreamChatCompletionRetryIntegration:
|
||||
errors = [e for e in events if isinstance(e, StreamError)]
|
||||
assert not errors, f"Unexpected StreamError: {errors}"
|
||||
assert any(isinstance(e, StreamStart) for e in events)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_result_message_success_subtype_prompt_too_long_triggers_compaction(
|
||||
self,
|
||||
):
|
||||
"""CLI returns ResultMessage(subtype="success") with result="Prompt is too long".
|
||||
|
||||
The SDK internally compacts but the transcript is still too long. It
|
||||
returns subtype="success" (process completed) with result="Prompt is
|
||||
too long" (the actual rejection message). The retry loop must detect
|
||||
this as a context-length error and trigger compaction — the subtype
|
||||
"success" must not fool it into treating this as a real response.
|
||||
"""
|
||||
import contextlib
|
||||
|
||||
from claude_agent_sdk import ResultMessage
|
||||
|
||||
from backend.copilot.response_model import StreamError, StreamStart
|
||||
from backend.copilot.sdk.service import stream_chat_completion_sdk
|
||||
|
||||
session = self._make_session()
|
||||
success_result = self._make_result_message()
|
||||
attempt_count = [0]
|
||||
|
||||
error_result = ResultMessage(
|
||||
subtype="success",
|
||||
result="Prompt is too long",
|
||||
duration_ms=100,
|
||||
duration_api_ms=0,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="test-session-id",
|
||||
)
|
||||
|
||||
def _client_factory(*args, **kwargs):
|
||||
attempt_count[0] += 1
|
||||
|
||||
async def _receive_error():
|
||||
yield error_result
|
||||
|
||||
async def _receive_success():
|
||||
yield success_result
|
||||
|
||||
client = MagicMock()
|
||||
client._transport = MagicMock()
|
||||
client._transport.write = AsyncMock()
|
||||
client.query = AsyncMock()
|
||||
if attempt_count[0] == 1:
|
||||
client.receive_response = _receive_error
|
||||
else:
|
||||
client.receive_response = _receive_success
|
||||
cm = AsyncMock()
|
||||
cm.__aenter__.return_value = client
|
||||
cm.__aexit__.return_value = None
|
||||
return cm
|
||||
|
||||
original_transcript = _build_transcript(
|
||||
[("user", "prior question"), ("assistant", "prior answer")]
|
||||
)
|
||||
compacted_transcript = _build_transcript(
|
||||
[("user", "[summary]"), ("assistant", "summary reply")]
|
||||
)
|
||||
|
||||
patches = _make_sdk_patches(
|
||||
session,
|
||||
original_transcript=original_transcript,
|
||||
compacted_transcript=compacted_transcript,
|
||||
client_side_effect=_client_factory,
|
||||
)
|
||||
|
||||
events = []
|
||||
with contextlib.ExitStack() as stack:
|
||||
for target, kwargs in patches:
|
||||
stack.enter_context(patch(target, **kwargs))
|
||||
async for event in stream_chat_completion_sdk(
|
||||
session_id="test-session-id",
|
||||
message="hello",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
session=session,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert attempt_count[0] == 2, (
|
||||
f"Expected 2 SDK attempts (subtype='success' with 'Prompt is too long' "
|
||||
f"result should trigger compaction retry), got {attempt_count[0]}"
|
||||
)
|
||||
errors = [e for e in events if isinstance(e, StreamError)]
|
||||
assert not errors, f"Unexpected StreamError: {errors}"
|
||||
assert any(isinstance(e, StreamStart) for e in events)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assistant_message_error_content_prompt_too_long_triggers_compaction(
|
||||
self,
|
||||
):
|
||||
"""AssistantMessage.error="invalid_request" with content "Prompt is too long".
|
||||
|
||||
The SDK returns error type "invalid_request" but puts the actual
|
||||
rejection message ("Prompt is too long") in the content blocks.
|
||||
The retry loop must detect this via content inspection (sdk_error
|
||||
being set confirms it's an error message, not user content).
|
||||
"""
|
||||
import contextlib
|
||||
|
||||
from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
|
||||
|
||||
from backend.copilot.response_model import StreamError, StreamStart
|
||||
from backend.copilot.sdk.service import stream_chat_completion_sdk
|
||||
|
||||
session = self._make_session()
|
||||
success_result = self._make_result_message()
|
||||
attempt_count = [0]
|
||||
|
||||
def _client_factory(*args, **kwargs):
|
||||
attempt_count[0] += 1
|
||||
|
||||
async def _receive_error():
|
||||
# SDK returns invalid_request with "Prompt is too long" in content.
|
||||
# ResultMessage.result is a non-PTL value ("done") to isolate
|
||||
# the AssistantMessage content detection path exclusively.
|
||||
yield AssistantMessage(
|
||||
content=[TextBlock(text="Prompt is too long")],
|
||||
model="<synthetic>",
|
||||
error="invalid_request",
|
||||
)
|
||||
yield ResultMessage(
|
||||
subtype="success",
|
||||
result="done",
|
||||
duration_ms=100,
|
||||
duration_api_ms=0,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="test-session-id",
|
||||
)
|
||||
|
||||
async def _receive_success():
|
||||
yield success_result
|
||||
|
||||
client = MagicMock()
|
||||
client._transport = MagicMock()
|
||||
client._transport.write = AsyncMock()
|
||||
client.query = AsyncMock()
|
||||
if attempt_count[0] == 1:
|
||||
client.receive_response = _receive_error
|
||||
else:
|
||||
client.receive_response = _receive_success
|
||||
cm = AsyncMock()
|
||||
cm.__aenter__.return_value = client
|
||||
cm.__aexit__.return_value = None
|
||||
return cm
|
||||
|
||||
original_transcript = _build_transcript(
|
||||
[("user", "prior question"), ("assistant", "prior answer")]
|
||||
)
|
||||
compacted_transcript = _build_transcript(
|
||||
[("user", "[summary]"), ("assistant", "summary reply")]
|
||||
)
|
||||
|
||||
patches = _make_sdk_patches(
|
||||
session,
|
||||
original_transcript=original_transcript,
|
||||
compacted_transcript=compacted_transcript,
|
||||
client_side_effect=_client_factory,
|
||||
)
|
||||
|
||||
events = []
|
||||
with contextlib.ExitStack() as stack:
|
||||
for target, kwargs in patches:
|
||||
stack.enter_context(patch(target, **kwargs))
|
||||
async for event in stream_chat_completion_sdk(
|
||||
session_id="test-session-id",
|
||||
message="hello",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
session=session,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert attempt_count[0] == 2, (
|
||||
f"Expected 2 SDK attempts (AssistantMessage error content 'Prompt is "
|
||||
f"too long' should trigger compaction retry), got {attempt_count[0]}"
|
||||
)
|
||||
errors = [e for e in events if isinstance(e, StreamError)]
|
||||
assert not errors, f"Unexpected StreamError: {errors}"
|
||||
assert any(isinstance(e, StreamStart) for e in events)
|
||||
|
||||
@@ -22,6 +22,38 @@ from .tool_adapter import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The SDK CLI uses "Task" in older versions and "Agent" in v2.x+.
|
||||
# Shared across all sessions — used by security hooks for sub-agent detection.
|
||||
_SUBAGENT_TOOLS: frozenset[str] = frozenset({"Task", "Agent"})
|
||||
|
||||
# Unicode ranges stripped by _sanitize():
|
||||
# - BiDi overrides (U+202A-U+202E, U+2066-U+2069) can trick reviewers
|
||||
# into misreading code/logs.
|
||||
# - Zero-width characters (U+200B-U+200F, U+FEFF) can hide content.
|
||||
_BIDI_AND_ZW_CHARS = set(
|
||||
chr(c)
|
||||
for r in (range(0x202A, 0x202F), range(0x2066, 0x206A), range(0x200B, 0x2010))
|
||||
for c in r
|
||||
) | {"\ufeff"}
|
||||
|
||||
|
||||
def _sanitize(value: str, max_len: int = 200) -> str:
|
||||
"""Strip control characters and truncate for safe logging.
|
||||
|
||||
Removes C0 (U+0000-U+001F), DEL (U+007F), C1 (U+0080-U+009F),
|
||||
Unicode BiDi overrides, and zero-width characters to prevent
|
||||
log injection and visual spoofing.
|
||||
"""
|
||||
cleaned = "".join(
|
||||
c
|
||||
for c in value
|
||||
if c >= " "
|
||||
and c != "\x7f"
|
||||
and not ("\x80" <= c <= "\x9f")
|
||||
and c not in _BIDI_AND_ZW_CHARS
|
||||
)
|
||||
return cleaned[:max_len]
|
||||
|
||||
|
||||
def _deny(reason: str) -> dict[str, Any]:
|
||||
"""Return a hook denial response."""
|
||||
@@ -136,11 +168,13 @@ def create_security_hooks(
|
||||
- PostToolUse: Log successful tool executions
|
||||
- PostToolUseFailure: Log and handle failed tool executions
|
||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||
- SubagentStart: Log sub-agent lifecycle start
|
||||
- SubagentStop: Log sub-agent lifecycle end
|
||||
|
||||
Args:
|
||||
user_id: Current user ID for isolation validation
|
||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
|
||||
max_subtasks: Maximum concurrent sub-agent spawns allowed per session
|
||||
on_compact: Callback invoked when SDK starts compacting context.
|
||||
Receives the transcript_path from the hook input.
|
||||
|
||||
@@ -151,9 +185,19 @@ def create_security_hooks(
|
||||
from claude_agent_sdk import HookMatcher
|
||||
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
||||
|
||||
# Per-session tracking for Task sub-agent concurrency.
|
||||
# Per-session tracking for sub-agent concurrency.
|
||||
# Set of tool_use_ids that consumed a slot — len() is the active count.
|
||||
task_tool_use_ids: set[str] = set()
|
||||
#
|
||||
# LIMITATION: For background (async) agents the SDK returns the
|
||||
# Agent/Task tool immediately with {isAsync: true}, which triggers
|
||||
# PostToolUse and releases the slot while the agent is still running.
|
||||
# SubagentStop fires later when the background process finishes but
|
||||
# does not currently hold a slot. This means the concurrency limit
|
||||
# only gates *launches*, not true concurrent execution. To fix this
|
||||
# we would need to track background agent_ids separately and release
|
||||
# in SubagentStop, but the SDK does not guarantee SubagentStop fires
|
||||
# for every background agent (e.g. on session abort).
|
||||
subagent_tool_use_ids: set[str] = set()
|
||||
|
||||
async def pre_tool_use_hook(
|
||||
input_data: HookInput,
|
||||
@@ -165,29 +209,22 @@ def create_security_hooks(
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
||||
|
||||
# Rate-limit Task (sub-agent) spawns per session
|
||||
if tool_name == "Task":
|
||||
# Block background task execution first — denied calls
|
||||
# should not consume a subtask slot.
|
||||
if tool_input.get("run_in_background"):
|
||||
logger.info(f"[SDK] Blocked background Task, user={user_id}")
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
"Background task execution is not supported. "
|
||||
"Run tasks in the foreground instead "
|
||||
"(remove the run_in_background parameter)."
|
||||
),
|
||||
)
|
||||
if len(task_tool_use_ids) >= max_subtasks:
|
||||
# Rate-limit sub-agent spawns per session.
|
||||
# The SDK CLI renamed "Task" → "Agent" in v2.x; handle both.
|
||||
if tool_name in _SUBAGENT_TOOLS:
|
||||
# Background agents are allowed — the SDK returns immediately
|
||||
# with {isAsync: true} and the model polls via TaskOutput.
|
||||
# Still count them against the concurrency limit.
|
||||
if len(subagent_tool_use_ids) >= max_subtasks:
|
||||
logger.warning(
|
||||
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
|
||||
f"[SDK] Sub-agent limit reached ({max_subtasks}), "
|
||||
f"user={user_id}"
|
||||
)
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
f"Maximum {max_subtasks} concurrent sub-tasks. "
|
||||
"Wait for running sub-tasks to finish, "
|
||||
f"Maximum {max_subtasks} concurrent sub-agents. "
|
||||
"Wait for running sub-agents to finish, "
|
||||
"or continue in the main conversation."
|
||||
),
|
||||
)
|
||||
@@ -208,20 +245,20 @@ def create_security_hooks(
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
# Reserve the Task slot only after all validations pass
|
||||
if tool_name == "Task" and tool_use_id is not None:
|
||||
task_tool_use_ids.add(tool_use_id)
|
||||
# Reserve the sub-agent slot only after all validations pass
|
||||
if tool_name in _SUBAGENT_TOOLS and tool_use_id is not None:
|
||||
subagent_tool_use_ids.add(tool_use_id)
|
||||
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
|
||||
"""Release a Task concurrency slot if one was reserved."""
|
||||
if tool_name == "Task" and tool_use_id in task_tool_use_ids:
|
||||
task_tool_use_ids.discard(tool_use_id)
|
||||
def _release_subagent_slot(tool_name: str, tool_use_id: str | None) -> None:
|
||||
"""Release a sub-agent concurrency slot if one was reserved."""
|
||||
if tool_name in _SUBAGENT_TOOLS and tool_use_id in subagent_tool_use_ids:
|
||||
subagent_tool_use_ids.discard(tool_use_id)
|
||||
logger.info(
|
||||
"[SDK] Task slot released, active=%d/%d, user=%s",
|
||||
len(task_tool_use_ids),
|
||||
"[SDK] Sub-agent slot released, active=%d/%d, user=%s",
|
||||
len(subagent_tool_use_ids),
|
||||
max_subtasks,
|
||||
user_id,
|
||||
)
|
||||
@@ -241,13 +278,14 @@ def create_security_hooks(
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
|
||||
_release_task_slot(tool_name, tool_use_id)
|
||||
_release_subagent_slot(tool_name, tool_use_id)
|
||||
is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
|
||||
safe_tool_use_id = _sanitize(str(tool_use_id or ""), max_len=12)
|
||||
logger.info(
|
||||
"[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
|
||||
tool_name,
|
||||
is_builtin,
|
||||
(tool_use_id or "")[:12],
|
||||
safe_tool_use_id,
|
||||
)
|
||||
|
||||
# Stash output for SDK built-in tools so the response adapter can
|
||||
@@ -256,7 +294,7 @@ def create_security_hooks(
|
||||
if is_builtin:
|
||||
tool_response = input_data.get("tool_response")
|
||||
if tool_response is not None:
|
||||
resp_preview = str(tool_response)[:100]
|
||||
resp_preview = _sanitize(str(tool_response), max_len=100)
|
||||
logger.info(
|
||||
"[SDK] Stashing builtin output for %s (%d chars): %s...",
|
||||
tool_name,
|
||||
@@ -280,13 +318,17 @@ def create_security_hooks(
|
||||
"""Log failed tool executions for debugging."""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
error = input_data.get("error", "Unknown error")
|
||||
error = _sanitize(str(input_data.get("error", "Unknown error")))
|
||||
safe_tool_use_id = _sanitize(str(tool_use_id or ""))
|
||||
logger.warning(
|
||||
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||
"[SDK] Tool failed: %s, error=%s, user=%s, tool_use_id=%s",
|
||||
tool_name,
|
||||
error,
|
||||
user_id,
|
||||
safe_tool_use_id,
|
||||
)
|
||||
|
||||
_release_task_slot(tool_name, tool_use_id)
|
||||
_release_subagent_slot(tool_name, tool_use_id)
|
||||
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
@@ -301,16 +343,14 @@ def create_security_hooks(
|
||||
This hook provides visibility into when compaction happens.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50)
|
||||
# Sanitize untrusted input: strip control chars for logging AND
|
||||
# for the value passed downstream. read_compacted_entries()
|
||||
# validates against _projects_base() as defence-in-depth, but
|
||||
# sanitizing here prevents log injection and rejects obviously
|
||||
# malformed paths early.
|
||||
transcript_path = (
|
||||
str(input_data.get("transcript_path", ""))
|
||||
.replace("\n", "")
|
||||
.replace("\r", "")
|
||||
transcript_path = _sanitize(
|
||||
str(input_data.get("transcript_path", "")), max_len=500
|
||||
)
|
||||
logger.info(
|
||||
"[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
|
||||
@@ -322,6 +362,44 @@ def create_security_hooks(
|
||||
on_compact(transcript_path)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def subagent_start_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log when a sub-agent starts execution."""
|
||||
_ = context, tool_use_id
|
||||
agent_id = _sanitize(str(input_data.get("agent_id", "?")))
|
||||
agent_type = _sanitize(str(input_data.get("agent_type", "?")))
|
||||
logger.info(
|
||||
"[SDK] SubagentStart: agent_id=%s, type=%s, user=%s",
|
||||
agent_id,
|
||||
agent_type,
|
||||
user_id,
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def subagent_stop_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log when a sub-agent stops."""
|
||||
_ = context, tool_use_id
|
||||
agent_id = _sanitize(str(input_data.get("agent_id", "?")))
|
||||
agent_type = _sanitize(str(input_data.get("agent_type", "?")))
|
||||
transcript = _sanitize(
|
||||
str(input_data.get("agent_transcript_path", "")), max_len=500
|
||||
)
|
||||
logger.info(
|
||||
"[SDK] SubagentStop: agent_id=%s, type=%s, user=%s, transcript=%s",
|
||||
agent_id,
|
||||
agent_type,
|
||||
user_id,
|
||||
transcript,
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
hooks: dict[str, Any] = {
|
||||
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
||||
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
||||
@@ -329,6 +407,8 @@ def create_security_hooks(
|
||||
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
|
||||
],
|
||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||
"SubagentStart": [HookMatcher(matcher="*", hooks=[subagent_start_hook])],
|
||||
"SubagentStop": [HookMatcher(matcher="*", hooks=[subagent_stop_hook])],
|
||||
}
|
||||
|
||||
return hooks
|
||||
|
||||
@@ -5,6 +5,7 @@ They validate that the security hooks correctly block unauthorized paths,
|
||||
tool access, and dangerous input patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
@@ -136,8 +137,20 @@ def test_read_tool_results_allowed():
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
|
||||
def test_read_tool_outputs_allowed():
|
||||
"""tool-outputs/ paths should be allowed, same as tool-results/."""
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-outputs/12345.txt"
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
|
||||
def test_read_claude_projects_settings_json_denied():
|
||||
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
|
||||
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/tool-outputs is."""
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
@@ -233,16 +246,15 @@ def _hooks():
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_background_blocked(_hooks):
|
||||
"""Task with run_in_background=true must be denied."""
|
||||
async def test_task_background_allowed(_hooks):
|
||||
"""Task with run_in_background=true is allowed (SDK handles async lifecycle)."""
|
||||
pre, _, _ = _hooks
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"run_in_background": True, "prompt": "x"}},
|
||||
tool_use_id=None,
|
||||
tool_use_id="tu-bg-1",
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert "foreground" in _reason(result).lower()
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@@ -356,3 +368,303 @@ async def test_task_slot_released_on_failure(_hooks):
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# "Agent" tool name (SDK v2.x+ renamed "Task" → "Agent")
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_background_allowed(_hooks):
|
||||
"""Agent with run_in_background=true is allowed (SDK handles async lifecycle)."""
|
||||
pre, _, _ = _hooks
|
||||
result = await pre(
|
||||
{
|
||||
"tool_name": "Agent",
|
||||
"tool_input": {"run_in_background": True, "prompt": "x"},
|
||||
},
|
||||
tool_use_id="tu-agent-bg-1",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_foreground_allowed(_hooks):
|
||||
"""Agent without run_in_background should be allowed."""
|
||||
pre, _, _ = _hooks
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "do stuff"}},
|
||||
tool_use_id="tu-agent-1",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_agent_counts_against_limit(_hooks):
|
||||
"""Background agents still consume concurrency slots."""
|
||||
pre, _, _ = _hooks
|
||||
# Two background agents fill the limit
|
||||
for i in range(2):
|
||||
result = await pre(
|
||||
{
|
||||
"tool_name": "Agent",
|
||||
"tool_input": {"run_in_background": True, "prompt": "bg"},
|
||||
},
|
||||
tool_use_id=f"tu-bglimit-{i}",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
# Third (background or foreground) should be denied
|
||||
result = await pre(
|
||||
{
|
||||
"tool_name": "Agent",
|
||||
"tool_input": {"run_in_background": True, "prompt": "over"},
|
||||
},
|
||||
tool_use_id="tu-bglimit-2",
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert "Maximum" in _reason(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_limit_enforced(_hooks):
|
||||
"""Agent spawns beyond max_subtasks should be denied."""
|
||||
pre, _, _ = _hooks
|
||||
# First two should pass
|
||||
for i in range(2):
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
|
||||
tool_use_id=f"tu-agent-limit-{i}",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
# Third should be denied (limit=2)
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "over limit"}},
|
||||
tool_use_id="tu-agent-limit-2",
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert "Maximum" in _reason(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_slot_released_on_completion(_hooks):
|
||||
"""Completing an Agent should free a slot so new Agents can be spawned."""
|
||||
pre, post, _ = _hooks
|
||||
# Fill both slots
|
||||
for i in range(2):
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
|
||||
tool_use_id=f"tu-agent-comp-{i}",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
# Third should be denied — at capacity
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "over"}},
|
||||
tool_use_id="tu-agent-comp-2",
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
# Complete first agent — frees a slot
|
||||
await post(
|
||||
{"tool_name": "Agent", "tool_input": {}},
|
||||
tool_use_id="tu-agent-comp-0",
|
||||
context={},
|
||||
)
|
||||
|
||||
# Now a new Agent should be allowed
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "after release"}},
|
||||
tool_use_id="tu-agent-comp-3",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_slot_released_on_failure(_hooks):
|
||||
"""A failed Agent should also free its concurrency slot."""
|
||||
pre, _, post_failure = _hooks
|
||||
# Fill both slots
|
||||
for i in range(2):
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
|
||||
tool_use_id=f"tu-agent-fail-{i}",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
# At capacity
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "over"}},
|
||||
tool_use_id="tu-agent-fail-2",
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
# Fail first agent — should free a slot
|
||||
await post_failure(
|
||||
{"tool_name": "Agent", "tool_input": {}, "error": "something broke"},
|
||||
tool_use_id="tu-agent-fail-0",
|
||||
context={},
|
||||
)
|
||||
|
||||
# New Agent should be allowed
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "after failure"}},
|
||||
tool_use_id="tu-agent-fail-3",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_task_agent_share_slots(_hooks):
|
||||
"""Task and Agent share the same concurrency pool."""
|
||||
pre, post, _ = _hooks
|
||||
# Fill one slot with Task, one with Agent
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
|
||||
tool_use_id="tu-mix-task",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
|
||||
tool_use_id="tu-mix-agent",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
# Third (either name) should be denied
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "over"}},
|
||||
tool_use_id="tu-mix-over",
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
# Release the Task slot
|
||||
await post(
|
||||
{"tool_name": "Task", "tool_input": {}},
|
||||
tool_use_id="tu-mix-task",
|
||||
context={},
|
||||
)
|
||||
|
||||
# Now an Agent should be allowed
|
||||
result = await pre(
|
||||
{"tool_name": "Agent", "tool_input": {"prompt": "after task release"}},
|
||||
tool_use_id="tu-mix-new",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentStart / SubagentStop hooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _subagent_hooks():
|
||||
"""Create hooks and return (subagent_start, subagent_stop) handlers."""
|
||||
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
|
||||
start = hooks["SubagentStart"][0].hooks[0]
|
||||
stop = hooks["SubagentStop"][0].hooks[0]
|
||||
return start, stop
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_start_hook_returns_empty(_subagent_hooks):
|
||||
"""SubagentStart hook should return an empty dict (logging only)."""
|
||||
start, _ = _subagent_hooks
|
||||
result = await start(
|
||||
{"agent_id": "sa-123", "agent_type": "research"},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_stop_hook_returns_empty(_subagent_hooks):
|
||||
"""SubagentStop hook should return an empty dict (logging only)."""
|
||||
_, stop = _subagent_hooks
|
||||
result = await stop(
|
||||
{
|
||||
"agent_id": "sa-123",
|
||||
"agent_type": "research",
|
||||
"agent_transcript_path": "/tmp/transcript.txt",
|
||||
},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_hooks_sanitize_inputs(_subagent_hooks, caplog):
|
||||
"""SubagentStart/Stop should sanitize control chars from inputs."""
|
||||
start, stop = _subagent_hooks
|
||||
# Inject control characters (C0, DEL, C1, BiDi overrides, zero-width)
|
||||
# — hook should not raise AND logs must be clean
|
||||
with caplog.at_level(logging.DEBUG, logger="backend.copilot.sdk.security_hooks"):
|
||||
result = await start(
|
||||
{
|
||||
"agent_id": "sa\n-injected\r\x00\x7f",
|
||||
"agent_type": "safe\x80_type\x9f\ttab",
|
||||
},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert result == {}
|
||||
# Control chars must be stripped from the logged values
|
||||
for record in caplog.records:
|
||||
assert "\x00" not in record.message
|
||||
assert "\r" not in record.message
|
||||
assert "\n" not in record.message
|
||||
assert "\x7f" not in record.message
|
||||
assert "\x80" not in record.message
|
||||
assert "\x9f" not in record.message
|
||||
assert "safe_type" in caplog.text
|
||||
|
||||
caplog.clear()
|
||||
with caplog.at_level(logging.DEBUG, logger="backend.copilot.sdk.security_hooks"):
|
||||
result = await stop(
|
||||
{
|
||||
"agent_id": "sa\n-injected\x7f",
|
||||
"agent_type": "type\r\x80\x9f",
|
||||
"agent_transcript_path": "/tmp/\x00malicious\npath\u202a\u200b",
|
||||
},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert result == {}
|
||||
for record in caplog.records:
|
||||
assert "\x00" not in record.message
|
||||
assert "\r" not in record.message
|
||||
assert "\n" not in record.message
|
||||
assert "\x7f" not in record.message
|
||||
assert "\u202a" not in record.message
|
||||
assert "\u200b" not in record.message
|
||||
assert "/tmp/maliciouspath" in caplog.text
|
||||
|
||||
@@ -29,16 +29,29 @@ from claude_agent_sdk import (
|
||||
)
|
||||
from langfuse import propagate_attributes
|
||||
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
||||
from opentelemetry import trace as otel_trace
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.copilot.permissions import apply_tool_permissions
|
||||
from backend.copilot.rate_limit import get_user_tier
|
||||
from backend.copilot.transcript import (
|
||||
_run_compression,
|
||||
cleanup_stale_project_dirs,
|
||||
compact_transcript,
|
||||
download_transcript,
|
||||
read_compacted_entries,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ..config import ChatConfig
|
||||
from ..config import ChatConfig, CopilotMode
|
||||
from ..constants import (
|
||||
COPILOT_ERROR_PREFIX,
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX,
|
||||
@@ -51,7 +64,7 @@ from ..model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
maybe_append_user_message,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from ..prompting import get_sdk_supplement
|
||||
@@ -70,11 +83,7 @@ from ..response_model import (
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from ..service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
_is_langfuse_configured,
|
||||
)
|
||||
from ..service import _build_system_prompt, _is_langfuse_configured, _update_title_async
|
||||
from ..token_tracking import persist_and_record_usage
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
@@ -92,17 +101,6 @@ from .tool_adapter import (
|
||||
set_execution_context,
|
||||
wait_for_stash,
|
||||
)
|
||||
from .transcript import (
|
||||
_run_compression,
|
||||
cleanup_stale_project_dirs,
|
||||
compact_transcript,
|
||||
download_transcript,
|
||||
read_compacted_entries,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
from .transcript_builder import TranscriptBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
@@ -129,6 +127,11 @@ _CIRCUIT_BREAKER_ERROR_MSG = (
|
||||
"Try breaking your request into smaller parts."
|
||||
)
|
||||
|
||||
# Idle timeout: abort the stream if no meaningful SDK message (only heartbeats)
|
||||
# arrives for this many seconds. This catches hung tool calls (e.g. WebSearch
|
||||
# hanging on a search provider that never responds).
|
||||
_IDLE_TIMEOUT_SECONDS = 10 * 60 # 10 minutes
|
||||
|
||||
# Patterns that indicate the prompt/request exceeds the model's context limit.
|
||||
# Matched case-insensitively against the full exception chain.
|
||||
_PROMPT_TOO_LONG_PATTERNS: tuple[str, ...] = (
|
||||
@@ -1271,6 +1274,8 @@ async def _run_stream_attempt(
|
||||
await client.query(state.query_message, session_id=ctx.session_id)
|
||||
state.transcript_builder.append_user(content=ctx.current_message)
|
||||
|
||||
_last_real_msg_time = time.monotonic()
|
||||
|
||||
async for sdk_msg in _iter_sdk_messages(client):
|
||||
# Heartbeat sentinel — refresh lock and keep SSE alive
|
||||
if sdk_msg is None:
|
||||
@@ -1278,8 +1283,34 @@ async def _run_stream_attempt(
|
||||
for ev in ctx.compaction.emit_start_if_ready():
|
||||
yield ev
|
||||
yield StreamHeartbeat()
|
||||
|
||||
# Idle timeout: if no real SDK message for too long, a tool
|
||||
# call is likely hung (e.g. WebSearch provider not responding).
|
||||
idle_seconds = time.monotonic() - _last_real_msg_time
|
||||
if idle_seconds >= _IDLE_TIMEOUT_SECONDS:
|
||||
logger.error(
|
||||
"%s Idle timeout after %.0fs with no SDK message — "
|
||||
"aborting stream (likely hung tool call)",
|
||||
ctx.log_prefix,
|
||||
idle_seconds,
|
||||
)
|
||||
stream_error_msg = (
|
||||
"A tool call appears to be stuck "
|
||||
"(no response for 10 minutes). "
|
||||
"Please try again."
|
||||
)
|
||||
stream_error_code = "idle_timeout"
|
||||
_append_error_marker(ctx.session, stream_error_msg, retryable=True)
|
||||
yield StreamError(
|
||||
errorText=stream_error_msg,
|
||||
code=stream_error_code,
|
||||
)
|
||||
ended_with_stream_error = True
|
||||
break
|
||||
continue
|
||||
|
||||
_last_real_msg_time = time.monotonic()
|
||||
|
||||
logger.info(
|
||||
"%s Received: %s %s (unresolved=%d, current=%d, resolved=%d)",
|
||||
ctx.log_prefix,
|
||||
@@ -1310,10 +1341,16 @@ async def _run_stream_attempt(
|
||||
# AssistantMessage.error (not as a Python exception).
|
||||
# Re-raise so the outer retry loop can compact the
|
||||
# transcript and retry with reduced context.
|
||||
# Only check error_text (the error field), not the
|
||||
# content preview — content may contain arbitrary text
|
||||
# that false-positives the pattern match.
|
||||
if _is_prompt_too_long(Exception(error_text)):
|
||||
# Check both error_text and error_preview: sdk_error
|
||||
# being set confirms this is an error message (not user
|
||||
# content), so checking content is safe. The actual
|
||||
# error description (e.g. "Prompt is too long") may be
|
||||
# in the content, not the error type field
|
||||
# (e.g. error="invalid_request", content="Prompt is
|
||||
# too long").
|
||||
if _is_prompt_too_long(Exception(error_text)) or _is_prompt_too_long(
|
||||
Exception(error_preview)
|
||||
):
|
||||
logger.warning(
|
||||
"%s Prompt-too-long detected via AssistantMessage "
|
||||
"error — raising for retry",
|
||||
@@ -1414,13 +1451,16 @@ async def _run_stream_attempt(
|
||||
ctx.log_prefix,
|
||||
sdk_msg.result or "(no error message provided)",
|
||||
)
|
||||
# If the CLI itself rejected the prompt as too long
|
||||
# (pre-API check, duration_api_ms=0), re-raise as an
|
||||
# exception so the retry loop can trigger compaction.
|
||||
# Without this, the ResultMessage is silently consumed
|
||||
# and the retry/compaction mechanism is never invoked.
|
||||
if _is_prompt_too_long(RuntimeError(sdk_msg.result or "")):
|
||||
raise RuntimeError("Prompt is too long")
|
||||
|
||||
# Check for prompt-too-long regardless of subtype — the
|
||||
# SDK may return subtype="success" with result="Prompt is
|
||||
# too long" when the CLI rejects the prompt before calling
|
||||
# the API (cost_usd=0, no tokens consumed). If we only
|
||||
# check the "error" subtype path, the stream appears to
|
||||
# complete normally, the synthetic error text is stored
|
||||
# in the transcript, and the session grows without bound.
|
||||
if _is_prompt_too_long(RuntimeError(sdk_msg.result or "")):
|
||||
raise RuntimeError("Prompt is too long")
|
||||
|
||||
# Capture token usage from ResultMessage.
|
||||
# Anthropic reports cached tokens separately:
|
||||
@@ -1453,6 +1493,23 @@ async def _run_stream_attempt(
|
||||
# Emit compaction end if SDK finished compacting.
|
||||
# Sync TranscriptBuilder with the CLI's active context.
|
||||
compact_result = await ctx.compaction.emit_end_if_ready(ctx.session)
|
||||
if compact_result.events:
|
||||
# Compaction events end with StreamFinishStep, which maps to
|
||||
# Vercel AI SDK's "finish-step" — that clears activeTextParts.
|
||||
# Close any open text block BEFORE the compaction events so
|
||||
# the text-end arrives before finish-step, preventing
|
||||
# "text-end for missing text part" errors on the frontend.
|
||||
pre_close: list[StreamBaseResponse] = []
|
||||
state.adapter._end_text_if_open(pre_close)
|
||||
# Compaction events bypass the adapter, so sync step state
|
||||
# when a StreamFinishStep is present — otherwise the adapter
|
||||
# will skip StreamStartStep on the next AssistantMessage.
|
||||
if any(
|
||||
isinstance(ev, StreamFinishStep) for ev in compact_result.events
|
||||
):
|
||||
state.adapter.step_open = False
|
||||
for r in pre_close:
|
||||
yield r
|
||||
for ev in compact_result.events:
|
||||
yield ev
|
||||
entries_replaced = False
|
||||
@@ -1502,9 +1559,21 @@ async def _run_stream_attempt(
|
||||
# --- Intermediate persistence ---
|
||||
# Flush session messages to DB periodically so page reloads
|
||||
# show progress during long-running turns.
|
||||
#
|
||||
# IMPORTANT: Skip the flush while tool calls are pending
|
||||
# (tool_calls set on assistant but results not yet received).
|
||||
# The DB save is append-only (uses start_sequence), so if we
|
||||
# flush the assistant message before tool_calls are set on it
|
||||
# (text and tool_use arrive as separate SDK events), the
|
||||
# tool_calls update is lost — the next flush starts past it.
|
||||
_msgs_since_flush += 1
|
||||
now = time.monotonic()
|
||||
if (
|
||||
has_pending_tools = (
|
||||
acc.has_appended_assistant
|
||||
and acc.accumulated_tool_calls
|
||||
and not acc.has_tool_results
|
||||
)
|
||||
if not has_pending_tools and (
|
||||
_msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
|
||||
or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
|
||||
):
|
||||
@@ -1604,6 +1673,7 @@ async def stream_chat_completion_sdk(
|
||||
session: ChatSession | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
mode: CopilotMode | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncIterator[StreamBaseResponse]:
|
||||
"""Stream chat completion using Claude Agent SDK.
|
||||
@@ -1612,7 +1682,10 @@ async def stream_chat_completion_sdk(
|
||||
file_ids: Optional workspace file IDs attached to the user's message.
|
||||
Images are embedded as vision content blocks; other files are
|
||||
saved to the SDK working directory for the Read tool.
|
||||
mode: Accepted for signature compatibility with the baseline path.
|
||||
The SDK path does not currently branch on this value.
|
||||
"""
|
||||
_ = mode # SDK path ignores the requested mode.
|
||||
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
@@ -1643,19 +1716,12 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
session.messages.pop()
|
||||
|
||||
# Append the new message to the session if it's not already there
|
||||
new_message_role = "user" if is_user_message else "assistant"
|
||||
if message and (
|
||||
len(session.messages) == 0
|
||||
or not (
|
||||
session.messages[-1].role == new_message_role
|
||||
and session.messages[-1].content == message
|
||||
)
|
||||
):
|
||||
session.messages.append(ChatMessage(role=new_message_role, content=message))
|
||||
if maybe_append_user_message(session, message, is_user_message):
|
||||
if is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(message or ""),
|
||||
)
|
||||
|
||||
# Structured log prefix: [SDK][<session>][T<turn>]
|
||||
@@ -1858,7 +1924,10 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
|
||||
# Fail fast when no API credentials are available at all.
|
||||
sdk_env = build_sdk_env(session_id=session_id, user_id=user_id)
|
||||
# sdk_cwd routes the CLI's temp dir into the per-session workspace
|
||||
# so sub-agent output files land inside sdk_cwd (see build_sdk_env).
|
||||
sdk_env = build_sdk_env(session_id=session_id, user_id=user_id, sdk_cwd=sdk_cwd)
|
||||
|
||||
if not config.api_key and not config.use_claude_code_subscription:
|
||||
raise RuntimeError(
|
||||
"No API key configured. Set OPEN_ROUTER_API_KEY, "
|
||||
@@ -1917,15 +1986,20 @@ async def stream_chat_completion_sdk(
|
||||
# langsmith tracing integration attaches them to every span. This
|
||||
# is what Langfuse (or any OTEL backend) maps to its native
|
||||
# user/session fields.
|
||||
_user_tier = await get_user_tier(user_id) if user_id else None
|
||||
_otel_metadata: dict[str, str] = {
|
||||
"resume": str(use_resume),
|
||||
"conversation_turn": str(turn),
|
||||
}
|
||||
if _user_tier:
|
||||
_otel_metadata["subscription_tier"] = _user_tier.value
|
||||
|
||||
_otel_ctx = propagate_attributes(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
trace_name="copilot-sdk",
|
||||
tags=["sdk"],
|
||||
metadata={
|
||||
"resume": str(use_resume),
|
||||
"conversation_turn": str(turn),
|
||||
},
|
||||
metadata=_otel_metadata,
|
||||
)
|
||||
_otel_ctx.__enter__()
|
||||
|
||||
@@ -2294,8 +2368,26 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
raise
|
||||
finally:
|
||||
# --- Close OTEL context ---
|
||||
# --- Close OTEL context (with cost attributes) ---
|
||||
if _otel_ctx is not None:
|
||||
try:
|
||||
span = otel_trace.get_current_span()
|
||||
if span and span.is_recording():
|
||||
span.set_attribute("gen_ai.usage.prompt_tokens", turn_prompt_tokens)
|
||||
span.set_attribute(
|
||||
"gen_ai.usage.completion_tokens", turn_completion_tokens
|
||||
)
|
||||
span.set_attribute(
|
||||
"gen_ai.usage.cache_read_tokens", turn_cache_read_tokens
|
||||
)
|
||||
span.set_attribute(
|
||||
"gen_ai.usage.cache_creation_tokens",
|
||||
turn_cache_creation_tokens,
|
||||
)
|
||||
if turn_cost_usd is not None:
|
||||
span.set_attribute("gen_ai.usage.cost_usd", turn_cost_usd)
|
||||
except Exception:
|
||||
logger.debug("Failed to set OTEL cost attributes", exc_info=True)
|
||||
try:
|
||||
_otel_ctx.__exit__(*sys.exc_info())
|
||||
except Exception:
|
||||
@@ -2313,6 +2405,8 @@ async def stream_chat_completion_sdk(
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
log_prefix=log_prefix,
|
||||
cost_usd=turn_cost_usd,
|
||||
model=config.model,
|
||||
provider="anthropic",
|
||||
)
|
||||
|
||||
# --- Persist session messages ---
|
||||
@@ -2417,18 +2511,3 @@ async def stream_chat_completion_sdk(
|
||||
finally:
|
||||
# Release stream lock to allow new streams for this session
|
||||
await lock.release()
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None = None
|
||||
) -> None:
|
||||
"""Background task to update session title."""
|
||||
try:
|
||||
title = await _generate_session_title(
|
||||
message, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
logger.debug("[SDK] Generated title for %s: %s", session_id, title)
|
||||
except Exception as e:
|
||||
logger.warning("[SDK] Failed to update session title: %s", e)
|
||||
|
||||
@@ -27,20 +27,19 @@ from backend.copilot.response_model import (
|
||||
StreamTextDelta,
|
||||
StreamTextStart,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_structured_transcript
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .service import _format_sdk_content_blocks
|
||||
from .transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_find_last_assistant_entry,
|
||||
_flatten_assistant_content,
|
||||
_messages_to_transcript,
|
||||
_rechain_tail,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_structured_transcript
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .service import _format_sdk_content_blocks
|
||||
from .transcript import compact_transcript, validate_transcript
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic thinking block content
|
||||
@@ -439,7 +438,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -498,7 +497,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
)()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
side_effect=mock_compression,
|
||||
):
|
||||
await compact_transcript(transcript, model="test-model")
|
||||
@@ -551,7 +550,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -601,7 +600,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -638,7 +637,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -699,7 +698,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
|
||||
@@ -38,7 +38,7 @@ from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
|
||||
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS, bridge_and_annotate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
@@ -387,7 +387,16 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
return _mcp_ok("".join(selected))
|
||||
#
|
||||
# When E2B is active, also copy the file into the sandbox so
|
||||
# bash_exec can process it (the model often uses Read then bash).
|
||||
text = "".join(selected)
|
||||
sandbox = _current_sandbox.get(None)
|
||||
if sandbox is not None:
|
||||
annotation = await bridge_and_annotate(sandbox, resolved, offset, limit)
|
||||
if annotation:
|
||||
text += annotation
|
||||
return _mcp_ok(text)
|
||||
except FileNotFoundError:
|
||||
return _mcp_err(f"File not found: {file_path}")
|
||||
except Exception as e:
|
||||
@@ -581,13 +590,14 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
# Security hooks validate that file paths stay within sdk_cwd.
|
||||
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
|
||||
# which provides kernel-level network isolation via unshare --net.
|
||||
# Task allows spawning sub-agents (rate-limited by security hooks).
|
||||
# Task/Agent allows spawning sub-agents (rate-limited by security hooks).
|
||||
# The CLI renamed "Task" → "Agent" in v2.x; both are listed for compat.
|
||||
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
|
||||
# TodoWrite manages the task checklist shown in the UI — no security concern.
|
||||
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
|
||||
# access. read_file also handles local tool-results and ephemeral reads.
|
||||
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
|
||||
_SDK_BUILTIN_ALWAYS = ["Task", "WebSearch", "TodoWrite"]
|
||||
_SDK_BUILTIN_ALWAYS = ["Task", "Agent", "WebSearch", "TodoWrite"]
|
||||
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
|
||||
|
||||
# SDK built-in tools that must be explicitly blocked.
|
||||
|
||||
@@ -619,3 +619,95 @@ class TestSDKDisallowedTools:
|
||||
def test_webfetch_tool_is_disallowed(self):
|
||||
"""WebFetch is disallowed due to SSRF risk."""
|
||||
assert "WebFetch" in SDK_DISALLOWED_TOOLS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — bridge_and_annotate integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadFileHandlerBridge:
|
||||
"""Verify that _read_file_handler calls bridge_and_annotate when a sandbox is active."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_context(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd="/tmp/copilot-bridge-test",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bridge_called_when_sandbox_active(self, tmp_path, monkeypatch):
|
||||
"""When a sandbox is set, bridge_and_annotate is called and its annotation appended."""
|
||||
from backend.copilot.context import _current_sandbox
|
||||
|
||||
from .tool_adapter import _read_file_handler
|
||||
|
||||
test_file = tmp_path / "tool-results" / "data.json"
|
||||
test_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
test_file.write_text('{"ok": true}\n')
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.tool_adapter.is_allowed_local_path",
|
||||
lambda path, cwd: True,
|
||||
)
|
||||
|
||||
fake_sandbox = object()
|
||||
token = _current_sandbox.set(fake_sandbox) # type: ignore[arg-type]
|
||||
try:
|
||||
bridge_calls: list[tuple] = []
|
||||
|
||||
async def fake_bridge_and_annotate(sandbox, file_path, offset, limit):
|
||||
bridge_calls.append((sandbox, file_path, offset, limit))
|
||||
return "\n[Sandbox copy available at /tmp/abc-data.json]"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.tool_adapter.bridge_and_annotate",
|
||||
fake_bridge_and_annotate,
|
||||
)
|
||||
|
||||
result = await _read_file_handler(
|
||||
{"file_path": str(test_file), "offset": 0, "limit": 2000}
|
||||
)
|
||||
|
||||
assert result["isError"] is False
|
||||
assert len(bridge_calls) == 1
|
||||
assert bridge_calls[0][0] is fake_sandbox
|
||||
assert "/tmp/abc-data.json" in result["content"][0]["text"]
|
||||
finally:
|
||||
_current_sandbox.reset(token)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bridge_not_called_without_sandbox(self, tmp_path, monkeypatch):
|
||||
"""When no sandbox is set, bridge_and_annotate is not called."""
|
||||
from .tool_adapter import _read_file_handler
|
||||
|
||||
test_file = tmp_path / "tool-results" / "data.json"
|
||||
test_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
test_file.write_text('{"ok": true}\n')
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.tool_adapter.is_allowed_local_path",
|
||||
lambda path, cwd: True,
|
||||
)
|
||||
|
||||
bridge_calls: list[tuple] = []
|
||||
|
||||
async def fake_bridge_and_annotate(sandbox, file_path, offset, limit):
|
||||
bridge_calls.append((sandbox, file_path, offset, limit))
|
||||
return "\n[Sandbox copy available at /tmp/abc-data.json]"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.tool_adapter.bridge_and_annotate",
|
||||
fake_bridge_and_annotate,
|
||||
)
|
||||
|
||||
result = await _read_file_handler(
|
||||
{"file_path": str(test_file), "offset": 0, "limit": 2000}
|
||||
)
|
||||
|
||||
assert result["isError"] is False
|
||||
assert len(bridge_calls) == 0
|
||||
assert "Sandbox copy" not in result["content"][0]["text"]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,235 +1,10 @@
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
"""Re-export from shared ``backend.copilot.transcript_builder`` for backward compat.
|
||||
|
||||
The transcript represents the FULL active context at any point in time.
|
||||
Each upload REPLACES the previous transcript atomically.
|
||||
|
||||
Flow:
|
||||
Turn 1: Upload [msg1, msg2]
|
||||
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
|
||||
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
|
||||
|
||||
The transcript is never incremental - always the complete atomic state.
|
||||
The canonical implementation now lives at ``backend.copilot.transcript_builder``
|
||||
so both the SDK and baseline paths can import without cross-package
|
||||
dependencies.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder, TranscriptEntry
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import STRIPPABLE_TYPES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranscriptEntry(BaseModel):
|
||||
"""Single transcript entry (user or assistant turn)."""
|
||||
|
||||
type: str
|
||||
uuid: str
|
||||
parentUuid: str | None
|
||||
isCompactSummary: bool | None = None
|
||||
message: dict[str, Any]
|
||||
|
||||
|
||||
class TranscriptBuilder:
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
|
||||
This builder maintains the FULL conversation state, not incremental changes.
|
||||
The output is always the complete active context.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: list[TranscriptEntry] = []
|
||||
self._last_uuid: str | None = None
|
||||
|
||||
def _last_is_assistant(self) -> bool:
|
||||
return bool(self._entries) and self._entries[-1].type == "assistant"
|
||||
|
||||
def _last_message_id(self) -> str:
|
||||
"""Return the message.id of the last entry, or '' if none."""
|
||||
if self._entries:
|
||||
return self._entries[-1].message.get("id", "")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _parse_entry(data: dict) -> TranscriptEntry | None:
|
||||
"""Parse a single transcript entry, filtering strippable types.
|
||||
|
||||
Returns ``None`` for entries that should be skipped (strippable types
|
||||
that are not compaction summaries).
|
||||
"""
|
||||
entry_type = data.get("type", "")
|
||||
if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
|
||||
return None
|
||||
return TranscriptEntry(
|
||||
type=entry_type,
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
isCompactSummary=data.get("isCompactSummary"),
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
|
||||
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
|
||||
"""Load complete previous transcript.
|
||||
|
||||
This loads the FULL previous context. As new messages come in,
|
||||
we append to this state. The final output is the complete context
|
||||
(previous + new), not just the delta.
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
data = json.loads(line, fallback=None)
|
||||
if data is None:
|
||||
logger.warning(
|
||||
"%s Failed to parse transcript line %d/%d",
|
||||
log_prefix,
|
||||
line_num,
|
||||
len(lines),
|
||||
)
|
||||
continue
|
||||
|
||||
entry = self._parse_entry(data)
|
||||
if entry is None:
|
||||
continue
|
||||
self._entries.append(entry)
|
||||
self._last_uuid = entry.uuid
|
||||
|
||||
logger.info(
|
||||
"%s Loaded %d entries from previous transcript (last_uuid=%s)",
|
||||
log_prefix,
|
||||
len(self._entries),
|
||||
self._last_uuid[:12] if self._last_uuid else None,
|
||||
)
|
||||
|
||||
def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
|
||||
"""Append a user entry."""
|
||||
msg_uuid = uuid or str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
TranscriptEntry(
|
||||
type="user",
|
||||
uuid=msg_uuid,
|
||||
parentUuid=self._last_uuid,
|
||||
message={"role": "user", "content": content},
|
||||
)
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def append_tool_result(self, tool_use_id: str, content: str) -> None:
|
||||
"""Append a tool result as a user entry (one per tool call)."""
|
||||
self.append_user(
|
||||
content=[
|
||||
{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
|
||||
]
|
||||
)
|
||||
|
||||
def append_assistant(
|
||||
self,
|
||||
content_blocks: list[dict],
|
||||
model: str = "",
|
||||
stop_reason: str | None = None,
|
||||
) -> None:
|
||||
"""Append an assistant entry.
|
||||
|
||||
Consecutive assistant entries automatically share the same message ID
|
||||
so the CLI can merge them (thinking → text → tool_use) into a single
|
||||
API message on ``--resume``. A new ID is assigned whenever an
|
||||
assistant entry follows a non-assistant entry (user message or tool
|
||||
result), because that marks the start of a new API response.
|
||||
"""
|
||||
message_id = (
|
||||
self._last_message_id()
|
||||
if self._last_is_assistant()
|
||||
else f"msg_sdk_{uuid4().hex[:24]}"
|
||||
)
|
||||
|
||||
msg_uuid = str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
TranscriptEntry(
|
||||
type="assistant",
|
||||
uuid=msg_uuid,
|
||||
parentUuid=self._last_uuid,
|
||||
message={
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"content": content_blocks,
|
||||
"stop_reason": stop_reason,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def replace_entries(
|
||||
self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
|
||||
) -> None:
|
||||
"""Replace all entries with compacted entries from the CLI session file.
|
||||
|
||||
Called after mid-stream compaction so TranscriptBuilder mirrors the
|
||||
CLI's active context (compaction summary + post-compaction entries).
|
||||
|
||||
Builds the new list first and validates it's non-empty before swapping,
|
||||
so corrupt input cannot wipe the conversation history.
|
||||
"""
|
||||
new_entries: list[TranscriptEntry] = []
|
||||
for data in compacted_entries:
|
||||
entry = self._parse_entry(data)
|
||||
if entry is not None:
|
||||
new_entries.append(entry)
|
||||
|
||||
if not new_entries:
|
||||
logger.warning(
|
||||
"%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
|
||||
log_prefix,
|
||||
len(compacted_entries),
|
||||
len(self._entries),
|
||||
)
|
||||
return
|
||||
|
||||
old_count = len(self._entries)
|
||||
self._entries = new_entries
|
||||
self._last_uuid = new_entries[-1].uuid
|
||||
|
||||
logger.info(
|
||||
"%s TranscriptBuilder compacted: %d entries -> %d entries",
|
||||
log_prefix,
|
||||
old_count,
|
||||
len(self._entries),
|
||||
)
|
||||
|
||||
def to_jsonl(self) -> str:
|
||||
"""Export complete context as JSONL.
|
||||
|
||||
Consecutive assistant entries are kept separate to match the
|
||||
native CLI format — the SDK merges them internally on resume.
|
||||
|
||||
Returns the FULL conversation state (all entries), not incremental.
|
||||
This output REPLACES any previous transcript.
|
||||
"""
|
||||
if not self._entries:
|
||||
return ""
|
||||
|
||||
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
@property
|
||||
def entry_count(self) -> int:
|
||||
"""Total number of entries in the complete context."""
|
||||
return len(self._entries)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Whether this builder has any entries."""
|
||||
return len(self._entries) == 0
|
||||
__all__ = ["TranscriptBuilder", "TranscriptEntry"]
|
||||
|
||||
@@ -303,7 +303,7 @@ class TestDeleteTranscript:
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -323,7 +323,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -341,7 +341,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -850,7 +850,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation(self):
|
||||
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated"}]
|
||||
@@ -858,11 +858,11 @@ class TestRunCompression:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
@@ -885,7 +885,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_success_returns_llm_result(self):
|
||||
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
llm_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "LLM summary"}]
|
||||
@@ -894,11 +894,11 @@ class TestRunCompression:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=llm_result,
|
||||
) as mock_compress,
|
||||
@@ -916,7 +916,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_failure_falls_back_to_truncation(self):
|
||||
"""Path (c): LLM call raises → truncation fallback used instead."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated fallback"}]
|
||||
@@ -932,11 +932,11 @@ class TestRunCompression:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
):
|
||||
@@ -953,7 +953,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_timeout_falls_back_to_truncation(self):
|
||||
"""Path (d): LLM call exceeds timeout → truncation fallback used."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated after timeout"}]
|
||||
@@ -970,19 +970,19 @@ class TestRunCompression:
|
||||
fake_client = MagicMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
"backend.copilot.transcript.get_openai_client",
|
||||
return_value=fake_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
"backend.copilot.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._COMPACTION_TIMEOUT_SECONDS",
|
||||
"backend.copilot.transcript._COMPACTION_TIMEOUT_SECONDS",
|
||||
0.05,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._TRUNCATION_TIMEOUT_SECONDS",
|
||||
"backend.copilot.transcript._TRUNCATION_TIMEOUT_SECONDS",
|
||||
5,
|
||||
),
|
||||
):
|
||||
@@ -1007,7 +1007,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories matching copilot pattern older than threshold are removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1039,12 +1039,12 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories not matching copilot pattern are left alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
from backend.copilot.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1062,7 +1062,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
|
||||
"""A directory exactly at the TTL boundary should NOT be removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1088,7 +1088,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
|
||||
"""Regular files matching the copilot pattern are not removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1114,11 +1114,11 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
|
||||
"""If the projects base directory doesn't exist, return 0 gracefully."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
from backend.copilot.transcript import cleanup_stale_project_dirs
|
||||
|
||||
nonexistent = str(tmp_path / "does-not-exist" / "projects")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: nonexistent,
|
||||
)
|
||||
|
||||
@@ -1129,7 +1129,7 @@ class TestCleanupStaleProjectDirs:
|
||||
"""When encoded_cwd is supplied only that directory is swept."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1160,12 +1160,12 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep leaves a fresh directory alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
from backend.copilot.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1181,7 +1181,7 @@ class TestCleanupStaleProjectDirs:
|
||||
"""Scoped sweep refuses to remove a non-copilot directory."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
"backend.copilot.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
|
||||
@@ -22,7 +22,12 @@ from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSessionInfo, get_chat_session, upsert_chat_session
|
||||
from .model import (
|
||||
ChatSessionInfo,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -202,6 +207,22 @@ async def _generate_session_title(
|
||||
return None
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None = None
|
||||
) -> None:
|
||||
"""Generate and persist a session title in the background.
|
||||
|
||||
Shared by both the SDK and baseline execution paths.
|
||||
"""
|
||||
try:
|
||||
title = await _generate_session_title(message, user_id, session_id)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
logger.debug("Generated title for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to update session title for %s: %s", session_id, e)
|
||||
|
||||
|
||||
async def assign_user_to_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import StreamError, StreamTextDelta
|
||||
from .sdk import service as sdk_service
|
||||
from .sdk.transcript import download_transcript
|
||||
from .transcript import download_transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user