mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
231 Commits
dx/add-age
...
fix/artifa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
351001fdca | ||
|
|
3a01874911 | ||
|
|
6d770d9917 | ||
|
|
334ec18c31 | ||
|
|
ea5cfdfa2e | ||
|
|
d13a85bef7 | ||
|
|
60b85640e7 | ||
|
|
87e4d42750 | ||
|
|
0339d95d12 | ||
|
|
f410929560 | ||
|
|
2bbec09e1a | ||
|
|
31b88a6e56 | ||
|
|
d357956d98 | ||
|
|
697ffa81f0 | ||
|
|
2b4727e8b2 | ||
|
|
0d4b31e8a1 | ||
|
|
0cd0a76305 | ||
|
|
d01a51be0e | ||
|
|
bd2efed080 | ||
|
|
5fccd8a762 | ||
|
|
2740b2be3a | ||
|
|
d27d22159d | ||
|
|
fffbe0aad8 | ||
|
|
df205b5444 | ||
|
|
4efa1c4310 | ||
|
|
ab3221a251 | ||
|
|
b2f7faabc7 | ||
|
|
c9fa6bcd62 | ||
|
|
c955b3901c | ||
|
|
56864aea87 | ||
|
|
d23ca824ad | ||
|
|
227c60abd3 | ||
|
|
0284614df0 | ||
|
|
f835674498 | ||
|
|
da18f372f7 | ||
|
|
d82ecac363 | ||
|
|
8a2e2365f7 | ||
|
|
55869d3c75 | ||
|
|
142c5dbe99 | ||
|
|
b06648de8c | ||
|
|
7240dd4fb1 | ||
|
|
b4cd00bea9 | ||
|
|
e17914d393 | ||
|
|
b3a58389e5 | ||
|
|
a3846e1e74 | ||
|
|
e5b0b7f18e | ||
|
|
92575ae76b | ||
|
|
44b58ca22c | ||
|
|
9de22eb053 | ||
|
|
55fe900650 | ||
|
|
bc6709dda1 | ||
|
|
b2b6f75420 | ||
|
|
573fb7163f | ||
|
|
c0306b1d21 | ||
|
|
b319c26cab | ||
|
|
85921f227a | ||
|
|
5844b13fb1 | ||
|
|
c014e1aa35 | ||
|
|
e59f576622 | ||
|
|
c99fa32ae3 | ||
|
|
b71789da50 | ||
|
|
5661326e7e | ||
|
|
df3fe926f2 | ||
|
|
505af7e673 | ||
|
|
d896a1f9fa | ||
|
|
6aa5a808e0 | ||
|
|
18c88b4da0 | ||
|
|
3a5ce570e0 | ||
|
|
5a3739e54d | ||
|
|
72bc8a92df | ||
|
|
cc29cf5e20 | ||
|
|
a0efbbba90 | ||
|
|
8ed959433a | ||
|
|
98f3e09580 | ||
|
|
9ec44dd109 | ||
|
|
bfb82b6246 | ||
|
|
63210770ce | ||
|
|
f2b8f81bb1 | ||
|
|
68b51ae2d3 | ||
|
|
63ff214563 | ||
|
|
9498daca31 | ||
|
|
ce0cb1e035 | ||
|
|
0d89f7bb33 | ||
|
|
aef9298be6 | ||
|
|
e5ea2e0d5b | ||
|
|
4eabc48053 | ||
|
|
101504ce0b | ||
|
|
2f67249d5f | ||
|
|
e73b5b3692 | ||
|
|
57c0c86a10 | ||
|
|
77d8362983 | ||
|
|
201d88b846 | ||
|
|
611a00d930 | ||
|
|
8d31bdb2dc | ||
|
|
2e64f3add7 | ||
|
|
b7f242f163 | ||
|
|
98c0920c04 | ||
|
|
4942249a60 | ||
|
|
0c94d884d0 | ||
|
|
54eaf7b818 | ||
|
|
be86a911e1 | ||
|
|
89091cb90f | ||
|
|
54763b660b | ||
|
|
835c8b0230 | ||
|
|
87539c03a4 | ||
|
|
f112555fc3 | ||
|
|
4e4aafca45 | ||
|
|
e68dadd2c9 | ||
|
|
d113687878 | ||
|
|
34abaa5a76 | ||
|
|
369ce7da16 | ||
|
|
70d53a0926 | ||
|
|
642c72e5e5 | ||
|
|
ba7929205d | ||
|
|
06c8882222 | ||
|
|
6d60265221 | ||
|
|
7b30a57112 | ||
|
|
7a08d9e0ca | ||
|
|
7c3a6f597a | ||
|
|
0b8997eb01 | ||
|
|
2ff036b86b | ||
|
|
b2d89c3a66 | ||
|
|
1fc3cc74ea | ||
|
|
815659d188 | ||
|
|
8c228afb15 | ||
|
|
afc7d3b252 | ||
|
|
0bd9b58da2 | ||
|
|
ca1577f3b1 | ||
|
|
2f3b29f589 | ||
|
|
5d0330615f | ||
|
|
cc6bf13e16 | ||
|
|
fce353fb21 | ||
|
|
8b8eb80480 | ||
|
|
875852be32 | ||
|
|
1e8a0f8d53 | ||
|
|
a22693a878 | ||
|
|
bb79cefb05 | ||
|
|
d31ff0586e | ||
|
|
3e35345efb | ||
|
|
478b60ce5d | ||
|
|
824ba15ff9 | ||
|
|
907518bfc3 | ||
|
|
15cedc6d17 | ||
|
|
28e7772db6 | ||
|
|
c390ab13fd | ||
|
|
7acfdf5974 | ||
|
|
ef477ae4b9 | ||
|
|
2879470185 | ||
|
|
705bd27930 | ||
|
|
fa6ea36488 | ||
|
|
cab061a12d | ||
|
|
6552d9bfdd | ||
|
|
f32a4087df | ||
|
|
eede293e11 | ||
|
|
31a2371c26 | ||
|
|
21670b20de | ||
|
|
ff8cdda4e8 | ||
|
|
c51097d8ac | ||
|
|
f3306d9211 | ||
|
|
19c8aecb97 | ||
|
|
d8181e7624 | ||
|
|
a4282d927a | ||
|
|
1c43d4a81d | ||
|
|
2897550d21 | ||
|
|
e058671325 | ||
|
|
a955b017f1 | ||
|
|
5f55980669 | ||
|
|
7f642f5b64 | ||
|
|
b3f25ecb57 | ||
|
|
f5e2eccda7 | ||
|
|
8f855e5ea7 | ||
|
|
6ed257225f | ||
|
|
109f28d9d1 | ||
|
|
ffa955044d | ||
|
|
0999739d19 | ||
|
|
58b230ff5a | ||
|
|
77f41d0cc6 | ||
|
|
5e8530b263 | ||
|
|
817b80a198 | ||
|
|
fbbd222405 | ||
|
|
67bdef13e7 | ||
|
|
e67dd93ee8 | ||
|
|
3140a60816 | ||
|
|
41c2ee9f83 | ||
|
|
ca748ee12a | ||
|
|
243b12778f | ||
|
|
43c81910ae | ||
|
|
a11199aa67 | ||
|
|
5f82a71d5f | ||
|
|
1a305db162 | ||
|
|
48a653dc63 | ||
|
|
f6ddcbc6cb | ||
|
|
98f13a6e5d | ||
|
|
613978a611 | ||
|
|
2b0e8a5a9f | ||
|
|
08bb05141c | ||
|
|
3ccaa5e103 | ||
|
|
09e42041ce | ||
|
|
a50e95f210 | ||
|
|
92b395d82a | ||
|
|
86abfbd394 | ||
|
|
a7f4093424 | ||
|
|
e33b1e2105 | ||
|
|
fff101e037 | ||
|
|
f1ac05b2e0 | ||
|
|
f115607779 | ||
|
|
1aef8b7155 | ||
|
|
0da949ba42 | ||
|
|
6b031085bd | ||
|
|
11b846dd49 | ||
|
|
b9e29c96bd | ||
|
|
4ac0ba570a | ||
|
|
d61a2c6cd0 | ||
|
|
1c301b4b61 | ||
|
|
24d0c35ed3 | ||
|
|
8aae7751dc | ||
|
|
725da7e887 | ||
|
|
bd9e9ec614 | ||
|
|
88589764b5 | ||
|
|
c659f3b058 | ||
|
|
80581a8364 | ||
|
|
3c046eb291 | ||
|
|
3e25488b2d | ||
|
|
57b17dc8e1 | ||
|
|
a20188ae59 | ||
|
|
c410be890e | ||
|
|
37d9863552 | ||
|
|
2f42ff9b47 | ||
|
|
914efc53e5 | ||
|
|
17e78ca382 | ||
|
|
1750c833ee |
1
.agents/skills
Symbolic link
1
.agents/skills
Symbolic link
@@ -0,0 +1 @@
|
||||
../.claude/skills
|
||||
10
.claude/settings.json
Normal file
10
.claude/settings.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allowedTools": [
|
||||
"Read", "Grep", "Glob",
|
||||
"Bash(ls:*)", "Bash(cat:*)", "Bash(grep:*)", "Bash(find:*)",
|
||||
"Bash(git status:*)", "Bash(git diff:*)", "Bash(git log:*)", "Bash(git worktree:*)",
|
||||
"Bash(tmux:*)", "Bash(sleep:*)", "Bash(branchlet:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
709
.claude/skills/orchestrate/SKILL.md
Normal file
709
.claude/skills/orchestrate/SKILL.md
Normal file
@@ -0,0 +1,709 @@
|
||||
---
|
||||
name: orchestrate
|
||||
description: "Meta-agent supervisor that manages a fleet of Claude Code agents running in tmux windows. Auto-discovers spare worktrees, spawns agents, monitors state, kicks idle agents, approves safe confirmations, and recycles worktrees when done. TRIGGER when user asks to supervise agents, run parallel tasks, manage worktrees, check agent status, or orchestrate parallel work."
|
||||
user-invocable: true
|
||||
argument-hint: "any free text — e.g. 'start 3 agents on X Y Z', 'show status', 'add task: implement feature A', 'stop', 'how many are free?'"
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "6.0.0"
|
||||
---
|
||||
|
||||
# Orchestrate — Agent Fleet Supervisor
|
||||
|
||||
One tmux session, N windows — each window is one agent working in its own worktree. Speak naturally; Claude maps your intent to the right scripts.
|
||||
|
||||
## Scripts
|
||||
|
||||
```bash
|
||||
SKILLS_DIR=$(git rev-parse --show-toplevel)/.claude/skills/orchestrate/scripts
|
||||
STATE_FILE=~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
| Script | Purpose |
|
||||
|---|---|
|
||||
| `find-spare.sh [REPO_ROOT]` | List free worktrees — one `PATH BRANCH` per line |
|
||||
| `spawn-agent.sh SESSION PATH SPARE NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]` | Create window + checkout branch + launch claude + send task. **Stdout: `SESSION:WIN` only** |
|
||||
| `recycle-agent.sh WINDOW PATH SPARE_BRANCH` | Kill window + restore spare branch |
|
||||
| `run-loop.sh` | **Mechanical babysitter** — idle restart + dialog approval + recycle on ORCHESTRATOR:DONE + supervisor health check + all-done notification |
|
||||
| `verify-complete.sh WINDOW` | Verify PR is done: checkpoints ✓ + 0 unresolved threads + CI green + no fresh CHANGES_REQUESTED. Repo auto-derived from state file `.repo` or git remote. |
|
||||
| `notify.sh MESSAGE` | Send notification via Discord webhook (env `DISCORD_WEBHOOK_URL` or state `.discord_webhook`), macOS notification center, and stdout |
|
||||
| `capacity.sh [REPO_ROOT]` | Print available + in-use worktrees |
|
||||
| `status.sh` | Print fleet status + live pane commands |
|
||||
| `poll-cycle.sh` | One monitoring cycle — classifies panes, tracks checkpoints, returns JSON action array |
|
||||
| `classify-pane.sh WINDOW` | Classify one pane state |
|
||||
|
||||
## Supervision model
|
||||
|
||||
```
|
||||
Orchestrating Claude (this Claude session — IS the supervisor)
|
||||
└── Reads pane output, checks CI, intervenes with targeted guidance
|
||||
run-loop.sh (separate tmux window, every 30s)
|
||||
└── Mechanical only: idle restart, dialog approval, recycle on ORCHESTRATOR:DONE
|
||||
```
|
||||
|
||||
**You (the orchestrating Claude)** are the supervisor. After spawning agents, stay in this conversation and actively monitor: poll each agent's pane every 2-3 minutes, check CI, nudge stalled agents, and verify completions. Do not spawn a separate supervisor Claude window — it loses context, is hard to observe, and compounds context compression problems.
|
||||
|
||||
**run-loop.sh** is the mechanical layer — zero tokens, handles things that need no judgment: restart crashed agents, press Enter on dialogs, recycle completed worktrees (only after `verify-complete.sh` passes).
|
||||
|
||||
## Checkpoint protocol
|
||||
|
||||
Agents output checkpoints as they complete each required step:
|
||||
|
||||
```
|
||||
CHECKPOINT:<step-name>
|
||||
```
|
||||
|
||||
Required steps are passed as args to `spawn-agent.sh` (e.g. `pr-address pr-test`). `run-loop.sh` will not recycle a window until all required checkpoints are found in the pane output. If `verify-complete.sh` fails, the agent is re-briefed automatically.
|
||||
|
||||
## Worktree lifecycle
|
||||
|
||||
```text
|
||||
spare/N branch → spawn-agent.sh (--session-id UUID) → window + feat/branch + claude running
|
||||
↓
|
||||
CHECKPOINT:<step> (as steps complete)
|
||||
↓
|
||||
ORCHESTRATOR:DONE
|
||||
↓
|
||||
verify-complete.sh: checkpoints ✓ + 0 threads + CI green + no fresh CHANGES_REQUESTED
|
||||
↓
|
||||
state → "done", notify, window KEPT OPEN
|
||||
↓
|
||||
user/orchestrator explicitly requests recycle
|
||||
↓
|
||||
recycle-agent.sh → spare/N (free again)
|
||||
```
|
||||
|
||||
**Windows are never auto-killed.** The worktree stays on its branch, the session stays alive. The agent is done working but the window, git state, and Claude session are all preserved until you choose to recycle.
|
||||
|
||||
**To resume a done or crashed session:**
|
||||
```bash
|
||||
# Resume by stored session ID (preferred — exact session, full context)
|
||||
claude --resume SESSION_ID --permission-mode bypassPermissions
|
||||
|
||||
# Or resume most recent session in that worktree directory
|
||||
cd /path/to/worktree && claude --continue --permission-mode bypassPermissions
|
||||
```
|
||||
|
||||
**To manually recycle when ready:**
|
||||
```bash
|
||||
bash ~/.claude/orchestrator/scripts/recycle-agent.sh SESSION:WIN WORKTREE_PATH spare/N
|
||||
# Then update state:
|
||||
jq --arg w "SESSION:WIN" '.agents |= map(if .window == $w then .state = "recycled" else . end)' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
## State file (`~/.claude/orchestrator-state.json`)
|
||||
|
||||
Never committed to git. You maintain this file directly using `jq` + atomic writes (`.tmp` → `mv`).
|
||||
|
||||
```json
|
||||
{
|
||||
"active": true,
|
||||
"tmux_session": "autogpt1",
|
||||
"idle_threshold_seconds": 300,
|
||||
"loop_window": "autogpt1:5",
|
||||
"repo": "Significant-Gravitas/AutoGPT",
|
||||
"discord_webhook": "https://discord.com/api/webhooks/...",
|
||||
"last_poll_at": 0,
|
||||
"agents": [
|
||||
{
|
||||
"window": "autogpt1:3",
|
||||
"worktree": "AutoGPT6",
|
||||
"worktree_path": "/path/to/AutoGPT6",
|
||||
"spare_branch": "spare/6",
|
||||
"branch": "feat/my-feature",
|
||||
"objective": "Implement X and open a PR",
|
||||
"pr_number": "12345",
|
||||
"session_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"steps": ["pr-address", "pr-test"],
|
||||
"checkpoints": ["pr-address"],
|
||||
"state": "running",
|
||||
"last_output_hash": "",
|
||||
"last_seen_at": 0,
|
||||
"spawned_at": 0,
|
||||
"idle_since": 0,
|
||||
"revision_count": 0,
|
||||
"last_rebriefed_at": 0
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Top-level optional fields:
|
||||
- `repo` — GitHub `owner/repo` for CI/thread checks. Auto-derived from git remote if omitted.
|
||||
- `discord_webhook` — Discord webhook URL for completion notifications. Also reads `DISCORD_WEBHOOK_URL` env var.
|
||||
|
||||
Per-agent fields:
|
||||
- `session_id` — UUID passed to `claude --session-id` at spawn; use with `claude --resume UUID` to restore exact session context after a crash or window close.
|
||||
- `last_rebriefed_at` — Unix timestamp of last re-brief; enforces 5-min cooldown to prevent spam.
|
||||
|
||||
Agent states: `running` | `idle` | `stuck` | `waiting_approval` | `complete` | `done` | `escalated`
|
||||
|
||||
`done` means verified complete — window is still open, session still alive, worktree still on task branch. Not recycled yet.
|
||||
|
||||
## Serial /pr-test rule
|
||||
|
||||
`/pr-test` and `/pr-test --fix` run local Docker + integration tests that use shared ports, a shared database, and shared build caches. **Running two `/pr-test` jobs simultaneously will cause port conflicts and database corruption.**
|
||||
|
||||
**Rule: only one `/pr-test` runs at a time. The orchestrator serializes them.**
|
||||
|
||||
You (the orchestrating Claude) own the test queue:
|
||||
1. Agents do `pr-review` and `pr-address` in parallel — that's safe (they only push code and reply to GitHub).
|
||||
2. When a PR needs local testing, add it to your mental queue — don't give agents a `pr-test` step.
|
||||
3. Run `/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER --fix` yourself, sequentially.
|
||||
4. Feed results back to the relevant agent via `tmux send-keys`:
|
||||
```bash
|
||||
tmux send-keys -t SESSION:WIN "Local tests for PR #N: <paste failure output or 'all passed'>. Fix any failures and push, then output ORCHESTRATOR:DONE."
|
||||
sleep 0.3
|
||||
tmux send-keys -t SESSION:WIN Enter
|
||||
```
|
||||
5. Wait for CI to confirm green before marking the agent done.
|
||||
|
||||
If multiple PRs need testing at the same time, pick the one furthest along (fewest pending CI checks) and test it first. Only start the next test after the previous one completes.
|
||||
|
||||
## Session restore (tested and confirmed)
|
||||
|
||||
Agent sessions are saved to disk. To restore a closed or crashed session:
|
||||
|
||||
```bash
|
||||
# If session_id is in state (preferred):
|
||||
NEW_WIN=$(tmux new-window -t SESSION -n WORKTREE_NAME -P -F '#{window_index}')
|
||||
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --resume SESSION_ID --permission-mode bypassPermissions" Enter
|
||||
|
||||
# If no session_id (use --continue for most recent session in that directory):
|
||||
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --continue --permission-mode bypassPermissions" Enter
|
||||
```
|
||||
|
||||
`--continue` restores the full conversation history including all tool calls, file edits, and context. The agent resumes exactly where it left off. After restoring, update the window address in the state file:
|
||||
|
||||
```bash
|
||||
jq --arg old "SESSION:OLD_WIN" --arg new "SESSION:NEW_WIN" \
|
||||
'(.agents[] | select(.window == $old)).window = $new' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
## Intent → action mapping
|
||||
|
||||
Match the user's message to one of these intents:
|
||||
|
||||
| The user says something like… | What to do |
|
||||
|---|---|
|
||||
| "status", "what's running", "show agents" | Run `status.sh` + `capacity.sh`, show output |
|
||||
| "how many free", "capacity", "available worktrees" | Run `capacity.sh`, show output |
|
||||
| "start N agents on X, Y, Z" or "run these tasks: …" | See **Spawning agents** below |
|
||||
| "add task: …", "add one more agent for …" | See **Adding an agent** below |
|
||||
| "stop", "shut down", "pause the fleet" | See **Stopping** below |
|
||||
| "poll", "check now", "run a cycle" | Run `poll-cycle.sh`, process actions |
|
||||
| "recycle window X", "free up autogpt3" | Run `recycle-agent.sh` directly |
|
||||
|
||||
When the intent is ambiguous, show capacity first and ask what tasks to run.
|
||||
|
||||
## Spawning agents
|
||||
|
||||
### 1. Resolve tmux session
|
||||
|
||||
```bash
|
||||
tmux list-sessions -F "#{session_name}: #{session_windows} windows" 2>/dev/null
|
||||
```
|
||||
|
||||
Use an existing session. **Never create a tmux session from within Claude** — it becomes a child of Claude's process and dies when the session ends. If no session exists, tell the user to run `tmux new-session -d -s autogpt1` in their terminal first, then re-invoke `/orchestrate`.
|
||||
|
||||
### 2. Show available capacity
|
||||
|
||||
```bash
|
||||
bash $SKILLS_DIR/capacity.sh $(git rev-parse --show-toplevel)
|
||||
```
|
||||
|
||||
### 3. Collect tasks from the user
|
||||
|
||||
For each task, gather:
|
||||
- **objective** — what to do (e.g. "implement feature X and open a PR")
|
||||
- **branch name** — e.g. `feat/my-feature` (derive from objective if not given)
|
||||
- **pr_number** — GitHub PR number if working on an existing PR (for verification)
|
||||
- **steps** — required checkpoint names in order (e.g. `pr-address pr-test`) — derive from objective
|
||||
|
||||
Ask for `idle_threshold_seconds` only if the user mentions it (default: 300).
|
||||
|
||||
Never ask the user to specify a worktree — auto-assign from `find-spare.sh`.
|
||||
|
||||
### 4. Spawn one agent per task
|
||||
|
||||
```bash
|
||||
# Get ordered list of spare worktrees
|
||||
SPARE_LIST=$(bash $SKILLS_DIR/find-spare.sh $(git rev-parse --show-toplevel))
|
||||
|
||||
# For each task, take the next spare line:
|
||||
WORKTREE_PATH=$(echo "$SPARE_LINE" | awk '{print $1}')
|
||||
SPARE_BRANCH=$(echo "$SPARE_LINE" | awk '{print $2}')
|
||||
|
||||
# With PR number and required steps:
|
||||
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE" "$PR_NUMBER" "pr-address" "pr-test")
|
||||
|
||||
# Without PR (new work):
|
||||
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE")
|
||||
```
|
||||
|
||||
Build an agent record and append it to the state file. If the state file doesn't exist yet, initialize it:
|
||||
|
||||
```bash
|
||||
# Derive repo from git remote (used by verify-complete.sh + supervisor)
|
||||
REPO=$(git remote get-url origin 2>/dev/null | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
|
||||
|
||||
jq -n \
|
||||
--arg session "$SESSION" \
|
||||
--arg repo "$REPO" \
|
||||
--argjson threshold 300 \
|
||||
'{active:true, tmux_session:$session, idle_threshold_seconds:$threshold,
|
||||
repo:$repo, loop_window:null, supervisor_window:null, last_poll_at:0, agents:[]}' \
|
||||
> ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
Optionally add a Discord webhook for completion notifications:
|
||||
```bash
|
||||
jq --arg hook "$DISCORD_WEBHOOK_URL" '.discord_webhook = $hook' ~/.claude/orchestrator-state.json \
|
||||
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
`spawn-agent.sh` writes the initial agent record (window, worktree_path, branch, objective, state, etc.) to the state file automatically — **do not append the record again after calling it.** The record already exists and `pr_number`/`steps` are patched in by the script itself.
|
||||
|
||||
### 5. Start the mechanical babysitter
|
||||
|
||||
```bash
|
||||
LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
|
||||
LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
|
||||
tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter
|
||||
|
||||
jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
|
||||
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
### 6. Begin supervising directly in this conversation
|
||||
|
||||
You are the supervisor. After spawning, immediately start your first poll loop (see **Supervisor duties** below) and continue every 2-3 minutes. Do NOT spawn a separate supervisor Claude window.
|
||||
|
||||
## Adding an agent
|
||||
|
||||
Find the next spare worktree, then spawn and append to state — same as steps 2–4 above but for a single task. If no spare worktrees are available, tell the user.
|
||||
|
||||
## Supervisor duties (YOUR job, every 2-3 min in this conversation)
|
||||
|
||||
You are the supervisor. Run this poll loop directly in your Claude session — not in a separate window.
|
||||
|
||||
### Poll loop mechanism
|
||||
|
||||
You are reactive — you only act when a tool completes or the user sends a message. To create a self-sustaining poll loop without user involvement:
|
||||
|
||||
1. Start each poll with `run_in_background: true` + a sleep before the work:
|
||||
```bash
|
||||
sleep 120 && tmux capture-pane -t autogpt1:0 -p -S -200 | tail -40
|
||||
# + similar for each active window
|
||||
```
|
||||
2. When the background job notifies you, read the pane output and take action.
|
||||
3. Immediately schedule the next background poll — this keeps the loop alive.
|
||||
4. Stop scheduling when all agents are done/escalated.
|
||||
|
||||
**Never tell the user "I'll poll every 2-3 minutes"** — that does nothing without a trigger. Start the background job instead.
|
||||
|
||||
### Each poll: what to check
|
||||
|
||||
```bash
|
||||
# 1. Read state
|
||||
cat ~/.claude/orchestrator-state.json | jq '.agents[] | {window, worktree, branch, state, pr_number, checkpoints}'
|
||||
|
||||
# 2. For each running/stuck/idle agent, capture pane
|
||||
tmux capture-pane -t SESSION:WIN -p -S -200 | tail -60
|
||||
```
|
||||
|
||||
For each agent, decide:
|
||||
|
||||
| What you see | Action |
|
||||
|---|---|
|
||||
| Spinner / tools running | Do nothing — agent is working |
|
||||
| Idle `❯` prompt, no `ORCHESTRATOR:DONE` | Stalled — send specific nudge with objective from state |
|
||||
| Stuck in error loop | Send targeted fix with exact error + solution |
|
||||
| Waiting for input / question | Answer and unblock via `tmux send-keys` |
|
||||
| CI red | `gh pr checks PR_NUMBER --repo REPO` → tell agent exactly what's failing |
|
||||
| GitHub abuse rate limit error | Nudge: "Wait 60 seconds then continue posting replies with sleep 3 between each" |
|
||||
| Context compacted / agent lost | Send recovery: `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="WIN")'` + `gh pr view PR_NUMBER --json title,body` |
|
||||
| `ORCHESTRATOR:DONE` in output | Query GraphQL for actual unresolved count. If >0, re-brief. If 0, run `verify-complete.sh` |
|
||||
|
||||
**Poll all windows from state, not from memory.** Before each poll, run:
|
||||
```bash
|
||||
jq -r '.agents[] | select(.state | test("running|idle|stuck|waiting_approval|pending_evaluation")) | .window' ~/.claude/orchestrator-state.json
|
||||
```
|
||||
and capture every window listed. If you manually added a window outside spawn-agent.sh, ensure it's in the state file first.
|
||||
|
||||
### RUNNING count includes waiting_approval agents
|
||||
|
||||
The `RUNNING` count from run-loop.sh includes agents in `waiting_approval` state (they match the regex `running|stuck|waiting_approval|idle`). This means a fleet that is only `waiting_approval` still shows RUNNING > 0 in the log — it does **not** mean agents are actively working.
|
||||
|
||||
When you see `RUNNING > 0` in the run-loop log but suspect agents are actually blocked, check state directly:
|
||||
```bash
|
||||
jq '.agents[] | {window, state, worktree}' ~/.claude/orchestrator-state.json
|
||||
```
|
||||
A count of `running=1 waiting=1` in the log actually means one agent is waiting for approval — the orchestrator should check and approve, not wait.
|
||||
|
||||
### State file staleness recovery
|
||||
|
||||
The state file is written by scripts but can drift from reality when windows are closed, sessions expire, or the orchestrator restarts across conversations.
|
||||
|
||||
**Signs of stale state:**
|
||||
- `loop_window` points to a window that no longer exists in the tmux session
|
||||
- An agent's `state` is `running` but tmux window is closed or shows a shell prompt (not claude)
|
||||
- `last_seen_at` is hours old but state still says `running`
|
||||
|
||||
**Recovery steps:**
|
||||
|
||||
1. **Verify actual tmux windows:**
|
||||
```bash
|
||||
tmux list-windows -t SESSION -F '#{window_index}: #{window_name} (#{pane_current_command})'
|
||||
```
|
||||
|
||||
2. **Cross-reference with state file:**
|
||||
```bash
|
||||
jq -r '.agents[] | "\(.window) \(.state) \(.worktree)"' ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
3. **Fix stale entries:**
|
||||
```bash
|
||||
# Agent window closed — mark idle so run-loop.sh will restart it
|
||||
jq --arg w "SESSION:WIN" '(.agents[] | select(.window==$w)).state = "idle"' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
|
||||
# loop_window gone — kill the stale reference, then restart run-loop.sh
|
||||
jq '.loop_window = null' ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
|
||||
LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
|
||||
tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter
|
||||
jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
|
||||
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
4. **After any state repair, re-run `status.sh` to confirm coherence before resuming supervision.**
|
||||
|
||||
### Strict ORCHESTRATOR:DONE gate
|
||||
|
||||
`verify-complete.sh` handles the main checks automatically (checkpoints, threads, CI green, spawned_at, and CHANGES_REQUESTED). Run it:
|
||||
|
||||
**CHANGES_REQUESTED staleness rule**: a `CHANGES_REQUESTED` review only blocks if it was submitted *after* the latest commit. If the latest commit postdates the review, the review is considered stale (feedback already addressed) and does not block. This avoids false negatives when a bot reviewer hasn't re-reviewed after the agent's fixing commits.
|
||||
|
||||
```bash
|
||||
SKILLS_DIR=~/.claude/orchestrator/scripts
|
||||
bash $SKILLS_DIR/verify-complete.sh SESSION:WIN
|
||||
```
|
||||
|
||||
If it passes → run-loop.sh will recycle the window automatically. No manual action needed.
|
||||
If it fails → re-brief the agent with the failure reason. Never manually mark state `done` to bypass this.
|
||||
|
||||
### Re-brief a stalled agent
|
||||
|
||||
**Before sending any nudge, verify the pane is at an idle ❯ prompt.** Sending text into a still-processing pane produces stuck `[Pasted text +N lines]` that the agent never sees.
|
||||
|
||||
Check:
|
||||
```bash
|
||||
tmux capture-pane -t SESSION:WIN -p 2>/dev/null | tail -5
|
||||
```
|
||||
If the last line shows a spinner (✳✽✢✶·), `Running…`, or no `❯` — wait 10–15s and check again before sending.
|
||||
|
||||
```bash
|
||||
OBJ=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .objective' ~/.claude/orchestrator-state.json)
|
||||
PR=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .pr_number' ~/.claude/orchestrator-state.json)
|
||||
tmux send-keys -t SESSION:WIN "You appear stalled. Your objective: $OBJ. Check: gh pr view $PR --json title,body,headRefName to reorient."
|
||||
sleep 0.3
|
||||
tmux send-keys -t SESSION:WIN Enter
|
||||
```
|
||||
|
||||
If `image_path` is set on the agent record, include: "Re-read context at IMAGE_PATH with the Read tool."
|
||||
|
||||
## Self-recovery protocol (agents)
|
||||
|
||||
spawn-agent.sh automatically includes this instruction in every objective:
|
||||
|
||||
> If your context compacts and you lose track of what to do, run:
|
||||
> `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="SESSION:WIN")'`
|
||||
> and `gh pr view PR_NUMBER --json title,body,headRefName` to reorient.
|
||||
> Output each completed step as `CHECKPOINT:<step-name>` on its own line.
|
||||
|
||||
## Passing images and screenshots to agents
|
||||
|
||||
`tmux send-keys` is text-only — you cannot paste a raw image into a pane. To give an agent visual context (screenshots, diagrams, mockups):
|
||||
|
||||
1. **Save the image to a temp file** with a stable path:
|
||||
```bash
|
||||
# If the user drags in a screenshot or you receive a file path:
|
||||
IMAGE_PATH="/tmp/orchestrator-context-$(date +%s).png"
|
||||
cp "$USER_PROVIDED_PATH" "$IMAGE_PATH"
|
||||
```
|
||||
|
||||
2. **Reference the path in the objective string**:
|
||||
```bash
|
||||
OBJECTIVE="Implement the layout shown in /tmp/orchestrator-context-1234567890.png. Read that image first with the Read tool to understand the design."
|
||||
```
|
||||
|
||||
3. The agent uses its `Read` tool to view the image at startup — Claude Code agents are multimodal and can read image files directly.
|
||||
|
||||
**Rule**: always use `/tmp/orchestrator-context-<timestamp>.png` as the naming convention so the supervisor knows what to look for if it needs to re-brief an agent with the same image.
|
||||
|
||||
---
|
||||
|
||||
## Orchestrator final evaluation (YOU decide, not the script)
|
||||
|
||||
`verify-complete.sh` is a gate — it blocks premature marking. But it cannot tell you if the work is actually good. That is YOUR job.
|
||||
|
||||
When run-loop marks an agent `pending_evaluation` and you're notified, do all of these before marking done:
|
||||
|
||||
### 1. Run /pr-test (required, serialized, use TodoWrite to queue)
|
||||
|
||||
`/pr-test` is the only reliable confirmation that the objective is actually met. Run it yourself, not the agent.
|
||||
|
||||
**When multiple PRs reach `pending_evaluation` at the same time, use TodoWrite to queue them:**
|
||||
```
|
||||
- [ ] /pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/NNNN — <feature description>
|
||||
- [ ] /pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/MMMM — <feature description>
|
||||
```
|
||||
Run one at a time. Check off as you go.
|
||||
|
||||
```
|
||||
/pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/PR_NUMBER
|
||||
```
|
||||
|
||||
**/pr-test can be lazy** — if it gives vague output, re-run with full context:
|
||||
|
||||
```
|
||||
/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER
|
||||
Context: This PR implements <objective from state file>. Key files: <list>.
|
||||
Please verify: <specific behaviors to check>.
|
||||
```
|
||||
|
||||
Only one `/pr-test` at a time — they share ports and DB.
|
||||
|
||||
### /pr-test result evaluation
|
||||
|
||||
**PARTIAL on any headline feature scenario is an immediate blocker.** Do not approve, do not mark done, do not let the agent output `ORCHESTRATOR:DONE`.
|
||||
|
||||
| `/pr-test` result | Action |
|
||||
|---|---|
|
||||
| All headline scenarios **PASS** | Proceed to evaluation step 2 |
|
||||
| Any headline scenario **PARTIAL** | Re-brief the agent immediately — see below |
|
||||
| Any headline scenario **FAIL** | Re-brief the agent immediately |
|
||||
|
||||
**What PARTIAL means**: the feature is only partly working. Example: the Apply button never appeared, or the AI returned no action blocks. The agent addressed part of the objective but not all of it.
|
||||
|
||||
**When any headline scenario is PARTIAL or FAIL:**
|
||||
|
||||
1. Do NOT mark the agent done or accept `ORCHESTRATOR:DONE`
|
||||
2. Re-brief the agent with the specific scenario that failed and what was missing:
|
||||
```bash
|
||||
tmux send-keys -t SESSION:WIN "PARTIAL result on /pr-test — S5 (Apply button) never appeared. The AI must output JSON action blocks for the Apply button to render. Fix this before re-running /pr-test."
|
||||
sleep 0.3
|
||||
tmux send-keys -t SESSION:WIN Enter
|
||||
```
|
||||
3. Set state back to `running`:
|
||||
```bash
|
||||
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "running"' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
4. Wait for new `ORCHESTRATOR:DONE`, then re-run `/pr-test` from scratch
|
||||
|
||||
**Rule: only ALL-PASS qualifies for approval.** A mix of PASS + PARTIAL is a failure.
|
||||
|
||||
> **Why this matters**: A PR was once wrongly approved with S5 PARTIAL — the AI never output JSON action blocks so the Apply button never appeared. The fix was already in the agent's reach but slipped through because PARTIAL was not treated as blocking.
|
||||
|
||||
### 2. Do your own evaluation
|
||||
|
||||
1. **Read the PR diff and objective** — does the code actually implement what was asked? Is anything obviously missing or half-done?
|
||||
2. **Read the resolved threads** — were comments addressed with real fixes, or just dismissed/resolved without changes?
|
||||
3. **Check CI run names** — any suspicious retries that shouldn't have passed?
|
||||
4. **Check the PR description** — title, summary, test plan complete?
|
||||
|
||||
### 3. Decide
|
||||
|
||||
- `/pr-test` all scenarios PASS + evaluation looks good → mark `done` in state, tell the user the PR is ready, ask if window should be closed
|
||||
- `/pr-test` any scenario PARTIAL or FAIL → re-brief the agent with the specific failing scenario, set state back to `running` (see `/pr-test result evaluation` above)
|
||||
- Evaluation finds gaps even with all PASS → re-brief the agent with specific gaps, set state back to `running`
|
||||
|
||||
**Never mark done based purely on script output.** You hold the full objective context; the script does not.
|
||||
|
||||
```bash
|
||||
# Mark done after your positive evaluation:
|
||||
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "done"' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
## When to stop the fleet
|
||||
|
||||
Stop the fleet (`active = false`) when **all** of the following are true:
|
||||
|
||||
| Check | How to verify |
|
||||
|---|---|
|
||||
| All agents are `done` or `escalated` | `jq '[.agents[] | select(.state | test("running\|stuck\|idle\|waiting_approval"))] | length' ~/.claude/orchestrator-state.json` == 0 |
|
||||
| All PRs have 0 unresolved review threads | GraphQL `isResolved` check per PR |
|
||||
| All PRs have green CI **on a run triggered after the agent's last push** | `gh run list --branch BRANCH --limit 1` timestamp > `spawned_at` in state |
|
||||
| No fresh CHANGES_REQUESTED (after latest commit) | `verify-complete.sh` checks this — stale pre-commit reviews are ignored |
|
||||
| No agents are `escalated` without human review | If any are escalated, surface to user first |
|
||||
|
||||
**Do NOT stop just because agents output `ORCHESTRATOR:DONE`.** That is a signal to verify, not a signal to stop.
|
||||
|
||||
**Do stop** if the user explicitly says "stop", "shut down", or "kill everything", even with agents still running.
|
||||
|
||||
```bash
|
||||
# Graceful stop
|
||||
jq '.active = false' ~/.claude/orchestrator-state.json > /tmp/orch.tmp \
|
||||
&& mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
|
||||
LOOP_WINDOW=$(jq -r '.loop_window // ""' ~/.claude/orchestrator-state.json)
|
||||
[ -n "$LOOP_WINDOW" ] && tmux kill-window -t "$LOOP_WINDOW" 2>/dev/null || true
|
||||
```
|
||||
|
||||
Does **not** recycle running worktrees — agents may still be mid-task. Run `capacity.sh` to see what's still in progress.
|
||||
|
||||
## tmux send-keys pattern
|
||||
|
||||
**Always split long messages into text + Enter as two separate calls with a sleep between them.** If sent as one call (`"text" Enter`), Enter can fire before the full string is buffered into Claude's input — leaving the message stuck as `[Pasted text +N lines]` unsent.
|
||||
|
||||
```bash
|
||||
# CORRECT — text then Enter separately
|
||||
tmux send-keys -t "$WINDOW" "your long message here"
|
||||
sleep 0.3
|
||||
tmux send-keys -t "$WINDOW" Enter
|
||||
|
||||
# WRONG — Enter may fire before text is buffered
|
||||
tmux send-keys -t "$WINDOW" "your long message here" Enter
|
||||
```
|
||||
|
||||
Short single-character sends (`y`, `Down`, empty Enter for dialog approval) are safe to combine since they have no buffering lag.
|
||||
|
||||
---
|
||||
|
||||
## Protected worktrees
|
||||
|
||||
Some worktrees must **never** be used as spare worktrees for agent tasks because they host files critical to the orchestrator itself:
|
||||
|
||||
| Worktree | Protected branch | Why |
|
||||
|---|---|---|
|
||||
| `AutoGPT1` | `dx/orchestrate-skill` | Hosts the orchestrate skill scripts. `recycle-agent.sh` would check out `spare/1`, wiping `.claude/skills/` and breaking all subsequent `spawn-agent.sh` calls. |
|
||||
|
||||
**Rule**: when selecting spare worktrees via `find-spare.sh`, skip any worktree whose CURRENT branch matches a protected branch. If you accidentally spawn an agent in a protected worktree, do not let `recycle-agent.sh` run on it — manually restore the branch after the agent finishes.
|
||||
|
||||
When `dx/orchestrate-skill` is merged into `dev`, `AutoGPT1` becomes a normal spare again.
|
||||
|
||||
---
|
||||
|
||||
## Thread resolution integrity (critical)
|
||||
|
||||
**Agents MUST NOT resolve review threads via GraphQL unless a real code fix has been committed and pushed first.**
|
||||
|
||||
This is the most common failure mode: agents call `resolveReviewThread` to make unresolved counts drop without actually fixing anything. This produces a false "done" signal that gets past verify-complete.sh.
|
||||
|
||||
**The only valid resolution sequence:**
|
||||
1. Read the thread and understand what it's asking
|
||||
2. Make the actual code change
|
||||
3. `git commit` and `git push`
|
||||
4. Reply to the thread with the commit SHA (e.g. "Fixed in `abc1234`")
|
||||
5. THEN call `resolveReviewThread`
|
||||
|
||||
**The supervisor must verify actual thread counts via GraphQL** — never trust an agent's claim of "0 unresolved." After any agent's ORCHESTRATOR:DONE, always run:
|
||||
|
||||
```bash
|
||||
# Step 1: get total count
|
||||
TOTAL=$(gh api graphql -f query='{ repository(owner: "OWNER", name: "REPO") { pullRequest(number: PR) { reviewThreads { totalCount } } } }' \
|
||||
| jq '.data.repository.pullRequest.reviewThreads.totalCount')
|
||||
echo "Total threads: $TOTAL"
|
||||
|
||||
# Step 2: paginate all pages and count unresolved
|
||||
CURSOR=""; UNRESOLVED=0
|
||||
while true; do
|
||||
AFTER=${CURSOR:+", after: \"$CURSOR\""}
|
||||
PAGE=$(gh api graphql -f query="{ repository(owner: \"OWNER\", name: \"REPO\") { pullRequest(number: PR) { reviewThreads(first: 100${AFTER}) { pageInfo { hasNextPage endCursor } nodes { isResolved } } } } }")
|
||||
UNRESOLVED=$(( UNRESOLVED + $(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved==false)] | length') ))
|
||||
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
|
||||
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
|
||||
[ "$HAS_NEXT" = "false" ] && break
|
||||
done
|
||||
echo "Unresolved: $UNRESOLVED"
|
||||
```
|
||||
|
||||
If unresolved > 0, the agent is NOT done — re-brief with the actual count and the rule.
|
||||
|
||||
**Include this in every agent objective:**
|
||||
> IMPORTANT: Do NOT resolve any review thread via GraphQL unless the code fix is committed and pushed first. Fix the code → commit → push → reply with SHA → then resolve. Never resolve without a real commit. "Accepted" or "Acknowledged" replies are NOT resolutions — only real commits qualify.
|
||||
|
||||
### Detecting fake resolutions
|
||||
|
||||
When an agent claims "0 unresolved threads", query GitHub GraphQL yourself and also inspect how each thread was resolved. A resolved thread whose last comment is `"Acknowledged"`, `"Same as above"`, `"Accepted trade-off"`, or `"Deferred"` — with no commit SHA — is a fake resolution.
|
||||
|
||||
To spot these, paginate all pages and collect resolved threads with missing SHA links:
|
||||
```bash
|
||||
# Paginate all pages — first:100 misses threads beyond page 1 on large PRs
|
||||
CURSOR=""; FAKE_RESOLUTIONS="[]"
|
||||
while true; do
|
||||
AFTER=${CURSOR:+", after: \"$CURSOR\""}
|
||||
PAGE=$(gh api graphql -f query="
|
||||
{
|
||||
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
|
||||
pullRequest(number: PR_NUMBER) {
|
||||
reviewThreads(first: 100${AFTER}) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes {
|
||||
isResolved
|
||||
comments(last: 1) {
|
||||
nodes { body author { login } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}")
|
||||
PAGE_FAKES=$(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[]
|
||||
| select(.isResolved == true)
|
||||
| {body: .comments.nodes[0].body[:120], author: .comments.nodes[0].author.login}
|
||||
| select(.body | test("Fixed in|Removed in|Addressed in") | not)]')
|
||||
FAKE_RESOLUTIONS=$(echo "$FAKE_RESOLUTIONS $PAGE_FAKES" | jq -s 'add')
|
||||
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
|
||||
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
|
||||
[ "$HAS_NEXT" = "false" ] && break
|
||||
done
|
||||
echo "$FAKE_RESOLUTIONS"
|
||||
```
|
||||
Any resolved thread whose last comment does NOT contain `"Fixed in"`, `"Removed in"`, or `"Addressed in"` (with a commit link) should be investigated — either the agent falsely resolved it, or it was a genuine false positive that needs explanation.
|
||||
|
||||
## GitHub abuse rate limits
|
||||
|
||||
Two distinct rate limits exist with different recovery times:
|
||||
|
||||
| Error | HTTP status | Cause | Recovery |
|
||||
|---|---|---|---|
|
||||
| `{"code":"abuse"}` in body | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **2–3 minutes**. 60s is often not enough. |
|
||||
| `API rate limit exceeded` | 429 | Primary rate limit — too many read calls per hour | Wait until `X-RateLimit-Reset` timestamp |
|
||||
|
||||
**Prevention:** Agents must add `sleep 3` between individual thread reply API calls. For >20 unresolved threads, increase to `sleep 5`.
|
||||
|
||||
If you see a 403 `abuse` error from an agent's pane:
|
||||
1. Nudge the agent: `"You hit a GitHub secondary rate limit (403). Stop all API writes. Wait 2 minutes, then resume with sleep 3 between each thread reply."`
|
||||
2. Do NOT nudge again during the 2-minute wait — a second nudge restarts the clock.
|
||||
|
||||
Add this to agent briefings when there are >20 unresolved threads:
|
||||
> Post replies with `sleep 3` between each reply. If you hit a 403 abuse error, wait 2 minutes (not 60s — secondary limits take longer to clear) then continue.
|
||||
|
||||
## Key rules
|
||||
|
||||
1. **Scripts do all the heavy lifting** — don't reimplement their logic inline in this file
|
||||
2. **Never ask the user to pick a worktree** — auto-assign from `find-spare.sh` output
|
||||
3. **Never restart a running agent** — only restart on `idle` kicks (foreground is a shell)
|
||||
4. **Auto-dismiss settings dialogs** — if "Enter to confirm" appears, send Down+Enter
|
||||
5. **Always `--permission-mode bypassPermissions`** on every spawn
|
||||
6. **Escalate after 3 kicks** — mark `escalated`, surface to user
|
||||
7. **Atomic state writes** — always write to `.tmp` then `mv`
|
||||
8. **Never approve destructive commands** outside the worktree scope — when in doubt, escalate
|
||||
9. **Never recycle without verification** — `verify-complete.sh` must pass before recycling
|
||||
10. **No TASK.md files** — commit risk; use state file + `gh pr view` for agent context persistence
|
||||
11. **Re-brief stalled agents** — read objective from state file + `gh pr view`, send via tmux
|
||||
12. **ORCHESTRATOR:DONE is a signal to verify, not to accept** — always run `verify-complete.sh` and check CI run timestamp before recycling
|
||||
13. **Protected worktrees** — never use the worktree hosting the skill scripts as a spare
|
||||
14. **Images via file path** — save screenshots to `/tmp/orchestrator-context-<ts>.png`, pass path in objective; agents read with the `Read` tool
|
||||
15. **Split send-keys** — always separate text and Enter with `sleep 0.3` between calls for long strings
|
||||
16. **Poll ALL windows from state file** — never hardcode window count. Derive active windows dynamically: `jq -r '.agents[] | select(.state | test("running|idle|stuck")) | .window' ~/.claude/orchestrator-state.json`. If you added a window mid-session outside spawn-agent.sh, add it to the state file immediately.
|
||||
20. **Orchestrator handles its own approvals** — when spawning a subagent to make edits (SKILL.md, scripts, config), review the diff yourself and approve/reject without surfacing it to the user. The user should never have to open a file to check the orchestrator's work. Use the Agent tool with `subagent_type: general-purpose` for drafting, then verify the result yourself before considering the task done.
|
||||
17. **Update state file on re-task** — whenever an agent is re-tasked mid-session (objective changes, new PR assigned), update the state file record immediately so objectives stay accurate for re-briefing after compaction.
|
||||
18. **No GraphQL resolveReviewThread without a commit** — see Thread resolution integrity above. This is rule #1 for pr-address work.
|
||||
19. **Verify thread counts yourself** — after any agent claims "0 unresolved threads", query GitHub GraphQL directly before accepting. Never trust the agent's self-report.
|
||||
43
.claude/skills/orchestrate/scripts/capacity.sh
Executable file
43
.claude/skills/orchestrate/scripts/capacity.sh
Executable file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env bash
|
||||
# capacity.sh — show fleet capacity: available spare worktrees + in-use agents
|
||||
#
|
||||
# Usage: capacity.sh [REPO_ROOT]
|
||||
# REPO_ROOT defaults to the root worktree of the current git repo.
|
||||
#
|
||||
# Reads: ~/.claude/orchestrator-state.json (skipped if missing or corrupt)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
|
||||
|
||||
echo "=== Available (spare) worktrees ==="
|
||||
if [ -n "$REPO_ROOT" ]; then
|
||||
SPARE=$("$SCRIPTS_DIR/find-spare.sh" "$REPO_ROOT" 2>/dev/null || echo "")
|
||||
else
|
||||
SPARE=$("$SCRIPTS_DIR/find-spare.sh" 2>/dev/null || echo "")
|
||||
fi
|
||||
|
||||
if [ -z "$SPARE" ]; then
|
||||
echo " (none)"
|
||||
else
|
||||
while IFS= read -r line; do
|
||||
[ -z "$line" ] && continue
|
||||
echo " ✓ $line"
|
||||
done <<< "$SPARE"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== In-use worktrees ==="
|
||||
if [ -f "$STATE_FILE" ] && jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
IN_USE=$(jq -r '.agents[] | select(.state != "done") | " [\(.state)] \(.worktree_path) → \(.branch)"' \
|
||||
"$STATE_FILE" 2>/dev/null || echo "")
|
||||
if [ -n "$IN_USE" ]; then
|
||||
echo "$IN_USE"
|
||||
else
|
||||
echo " (none)"
|
||||
fi
|
||||
else
|
||||
echo " (no active state file)"
|
||||
fi
|
||||
85
.claude/skills/orchestrate/scripts/classify-pane.sh
Executable file
85
.claude/skills/orchestrate/scripts/classify-pane.sh
Executable file
@@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env bash
|
||||
# classify-pane.sh — Classify the current state of a tmux pane
|
||||
#
|
||||
# Usage: classify-pane.sh <tmux-target>
|
||||
# tmux-target: e.g. "work:0", "work:1.0"
|
||||
#
|
||||
# Output (stdout): JSON object:
|
||||
# { "state": "running|idle|waiting_approval|complete", "reason": "...", "pane_cmd": "..." }
|
||||
#
|
||||
# Exit codes: 0=ok, 1=error (invalid target or tmux window not found)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
TARGET="${1:-}"
|
||||
|
||||
if [ -z "$TARGET" ]; then
|
||||
echo '{"state":"error","reason":"no target provided","pane_cmd":""}'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate tmux target format: session:window or session:window.pane
|
||||
if ! [[ "$TARGET" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
|
||||
echo '{"state":"error","reason":"invalid tmux target format","pane_cmd":""}'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check session exists (use %%:* to extract session name from session:window)
|
||||
if ! tmux list-windows -t "${TARGET%%:*}" &>/dev/null 2>&1; then
|
||||
echo '{"state":"error","reason":"tmux target not found","pane_cmd":""}'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get the current foreground command in the pane
|
||||
PANE_CMD=$(tmux display-message -t "$TARGET" -p '#{pane_current_command}' 2>/dev/null || echo "unknown")
|
||||
|
||||
# Capture and strip ANSI codes (use perl for cross-platform compatibility — BSD sed lacks \x1b support)
|
||||
RAW=$(tmux capture-pane -t "$TARGET" -p -S -50 2>/dev/null || echo "")
|
||||
CLEAN=$(echo "$RAW" | perl -pe 's/\x1b\[[0-9;]*[a-zA-Z]//g; s/\x1b\(B//g; s/\x1b\[\?[0-9]*[hl]//g; s/\r//g' \
|
||||
| grep -v '^[[:space:]]*$' || true)
|
||||
|
||||
# --- Check: explicit completion marker ---
|
||||
# Must be on its own line (not buried in the objective text sent at spawn time).
|
||||
if echo "$CLEAN" | grep -qE "^[[:space:]]*ORCHESTRATOR:DONE[[:space:]]*$"; then
|
||||
jq -n --arg cmd "$PANE_CMD" '{"state":"complete","reason":"ORCHESTRATOR:DONE marker found","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# --- Check: Claude Code approval prompt patterns ---
|
||||
LAST_40=$(echo "$CLEAN" | tail -40)
|
||||
APPROVAL_PATTERNS=(
|
||||
"Do you want to proceed"
|
||||
"Do you want to make this"
|
||||
"\\[y/n\\]"
|
||||
"\\[Y/n\\]"
|
||||
"\\[n/Y\\]"
|
||||
"Proceed\\?"
|
||||
"Allow this command"
|
||||
"Run bash command"
|
||||
"Allow bash"
|
||||
"Would you like"
|
||||
"Press enter to continue"
|
||||
"Esc to cancel"
|
||||
)
|
||||
for pattern in "${APPROVAL_PATTERNS[@]}"; do
|
||||
if echo "$LAST_40" | grep -qiE "$pattern"; then
|
||||
jq -n --arg pattern "$pattern" --arg cmd "$PANE_CMD" \
|
||||
'{"state":"waiting_approval","reason":"approval pattern: \($pattern)","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
fi
|
||||
done
|
||||
|
||||
# --- Check: shell prompt (claude has exited) ---
|
||||
# If the foreground process is a shell (not claude/node), the agent has exited
|
||||
case "$PANE_CMD" in
|
||||
zsh|bash|fish|sh|dash|tcsh|ksh)
|
||||
jq -n --arg cmd "$PANE_CMD" \
|
||||
'{"state":"idle","reason":"agent exited — shell prompt active","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
;;
|
||||
esac
|
||||
|
||||
# Agent is still running (claude/node/python is the foreground process)
|
||||
jq -n --arg cmd "$PANE_CMD" \
|
||||
'{"state":"running","reason":"foreground process: \($cmd)","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
24
.claude/skills/orchestrate/scripts/find-spare.sh
Executable file
24
.claude/skills/orchestrate/scripts/find-spare.sh
Executable file
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env bash
|
||||
# find-spare.sh — list worktrees on spare/N branches (free to use)
|
||||
#
|
||||
# Usage: find-spare.sh [REPO_ROOT]
|
||||
# REPO_ROOT defaults to the root worktree containing the current git repo.
|
||||
#
|
||||
# Output (stdout): one line per available worktree: "PATH BRANCH"
|
||||
# e.g.: /Users/me/Code/AutoGPT3 spare/3
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
|
||||
if [ -z "$REPO_ROOT" ]; then
|
||||
echo "Error: not inside a git repo and no REPO_ROOT provided" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
git -C "$REPO_ROOT" worktree list --porcelain \
|
||||
| awk '
|
||||
/^worktree / { path = substr($0, 10) }
|
||||
/^branch / { branch = substr($0, 8); print path " " branch }
|
||||
' \
|
||||
| { grep -E " refs/heads/spare/[0-9]+$" || true; } \
|
||||
| sed 's|refs/heads/||'
|
||||
40
.claude/skills/orchestrate/scripts/notify.sh
Executable file
40
.claude/skills/orchestrate/scripts/notify.sh
Executable file
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env bash
|
||||
# notify.sh — send a fleet notification message
|
||||
#
|
||||
# Delivery order (first available wins):
|
||||
# 1. Discord webhook — DISCORD_WEBHOOK_URL env var OR state file .discord_webhook
|
||||
# 2. macOS notification center — osascript (silent fail if unavailable)
|
||||
# 3. Stdout only
|
||||
#
|
||||
# Usage: notify.sh MESSAGE
|
||||
# Exit: always 0 (notification failure must not abort the caller)
|
||||
|
||||
MESSAGE="${1:-}"
|
||||
[ -z "$MESSAGE" ] && exit 0
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
# --- Resolve Discord webhook ---
|
||||
WEBHOOK="${DISCORD_WEBHOOK_URL:-}"
|
||||
if [ -z "$WEBHOOK" ] && [ -f "$STATE_FILE" ]; then
|
||||
WEBHOOK=$(jq -r '.discord_webhook // ""' "$STATE_FILE" 2>/dev/null || echo "")
|
||||
fi
|
||||
|
||||
# --- Discord delivery ---
|
||||
if [ -n "$WEBHOOK" ]; then
|
||||
PAYLOAD=$(jq -n --arg msg "$MESSAGE" '{"content": $msg}')
|
||||
curl -s -X POST "$WEBHOOK" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "$PAYLOAD" > /dev/null 2>&1 || true
|
||||
fi
|
||||
|
||||
# --- macOS notification center (silent if not macOS or osascript missing) ---
|
||||
if command -v osascript &>/dev/null 2>&1; then
|
||||
# Escape single quotes for AppleScript
|
||||
SAFE_MSG=$(echo "$MESSAGE" | sed "s/'/\\\\'/g")
|
||||
osascript -e "display notification \"${SAFE_MSG}\" with title \"Orchestrator\"" 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# Always print to stdout so run-loop.sh logs it
|
||||
echo "$MESSAGE"
|
||||
exit 0
|
||||
257
.claude/skills/orchestrate/scripts/poll-cycle.sh
Executable file
257
.claude/skills/orchestrate/scripts/poll-cycle.sh
Executable file
@@ -0,0 +1,257 @@
|
||||
#!/usr/bin/env bash
|
||||
# poll-cycle.sh — Single orchestrator poll cycle
|
||||
#
|
||||
# Reads ~/.claude/orchestrator-state.json, classifies each agent, updates state,
|
||||
# and outputs a JSON array of actions for Claude to take.
|
||||
#
|
||||
# Usage: poll-cycle.sh
|
||||
# Output (stdout): JSON array of action objects
|
||||
# [{ "window": "work:0", "action": "kick|approve|none", "state": "...",
|
||||
# "worktree": "...", "objective": "...", "reason": "..." }]
|
||||
#
|
||||
# The state file is updated in-place (atomic write via .tmp).
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
CLASSIFY="$SCRIPTS_DIR/classify-pane.sh"
|
||||
|
||||
# Cross-platform md5: always outputs just the hex digest
|
||||
md5_hash() {
|
||||
if command -v md5sum &>/dev/null; then
|
||||
md5sum | awk '{print $1}'
|
||||
else
|
||||
md5 | awk '{print $NF}'
|
||||
fi
|
||||
}
|
||||
|
||||
# Clean up temp file on any exit (avoids stale .tmp if jq write fails)
|
||||
trap 'rm -f "${STATE_FILE}.tmp"' EXIT
|
||||
|
||||
# Ensure state file exists
|
||||
if [ ! -f "$STATE_FILE" ]; then
|
||||
echo '{"active":false,"agents":[]}' > "$STATE_FILE"
|
||||
fi
|
||||
|
||||
# Validate JSON upfront before any jq reads that run under set -e.
|
||||
# A truncated/corrupt file (e.g. from a SIGKILL mid-write) would otherwise
|
||||
# abort the script at the ACTIVE read below without emitting any JSON output.
|
||||
if ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
echo "State file parse error — check $STATE_FILE" >&2
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
ACTIVE=$(jq -r '.active // false' "$STATE_FILE")
|
||||
if [ "$ACTIVE" != "true" ]; then
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
NOW=$(date +%s)
|
||||
IDLE_THRESHOLD=$(jq -r '.idle_threshold_seconds // 300' "$STATE_FILE")
|
||||
|
||||
ACTIONS="[]"
|
||||
UPDATED_AGENTS="[]"
|
||||
|
||||
# Read agents as newline-delimited JSON objects.
|
||||
# jq exits non-zero when .agents[] has no matches on an empty array, which is valid —
|
||||
# so we suppress that exit code and separately validate the file is well-formed JSON.
|
||||
if ! AGENTS_JSON=$(jq -e -c '.agents // empty | .[]' "$STATE_FILE" 2>/dev/null); then
|
||||
if ! jq -e '.' "$STATE_FILE" > /dev/null 2>&1; then
|
||||
echo "State file parse error — check $STATE_FILE" >&2
|
||||
fi
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ -z "$AGENTS_JSON" ]; then
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
while IFS= read -r agent; do
|
||||
[ -z "$agent" ] && continue
|
||||
|
||||
# Use // "" defaults so a single malformed field doesn't abort the whole cycle
|
||||
WINDOW=$(echo "$agent" | jq -r '.window // ""')
|
||||
WORKTREE=$(echo "$agent" | jq -r '.worktree // ""')
|
||||
OBJECTIVE=$(echo "$agent"| jq -r '.objective // ""')
|
||||
STATE=$(echo "$agent" | jq -r '.state // "running"')
|
||||
LAST_HASH=$(echo "$agent"| jq -r '.last_output_hash // ""')
|
||||
IDLE_SINCE=$(echo "$agent"| jq -r '.idle_since // 0')
|
||||
REVISION_COUNT=$(echo "$agent"| jq -r '.revision_count // 0')
|
||||
|
||||
# Validate window format to prevent tmux target injection.
|
||||
# Allow session:window (numeric or named) and session:window.pane
|
||||
if ! [[ "$WINDOW" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
|
||||
echo "Skipping agent with invalid window value: $WINDOW" >&2
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
|
||||
continue
|
||||
fi
|
||||
|
||||
# Pass-through terminal-state agents
|
||||
if [[ "$STATE" == "done" || "$STATE" == "escalated" || "$STATE" == "complete" || "$STATE" == "pending_evaluation" ]]; then
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
|
||||
continue
|
||||
fi
|
||||
|
||||
# Classify pane.
|
||||
# classify-pane.sh always emits JSON before exit (even on error), so using
|
||||
# "|| echo '...'" would concatenate two JSON objects when it exits non-zero.
|
||||
# Use "|| true" inside the substitution so set -euo pipefail does not abort
|
||||
# the poll cycle when classify exits with a non-zero status code.
|
||||
CLASSIFICATION=$("$CLASSIFY" "$WINDOW" 2>/dev/null || true)
|
||||
[ -z "$CLASSIFICATION" ] && CLASSIFICATION='{"state":"error","reason":"classify failed","pane_cmd":"unknown"}'
|
||||
|
||||
PANE_STATE=$(echo "$CLASSIFICATION" | jq -r '.state')
|
||||
PANE_REASON=$(echo "$CLASSIFICATION" | jq -r '.reason')
|
||||
|
||||
# Capture full pane output once — used for hash (stuck detection) and checkpoint parsing.
|
||||
# Use -S -500 to get the last ~500 lines of scrollback so checkpoints aren't missed.
|
||||
RAW=$(tmux capture-pane -t "$WINDOW" -p -S -500 2>/dev/null || echo "")
|
||||
|
||||
# --- Checkpoint tracking ---
|
||||
# Parse any "CHECKPOINT:<step>" lines the agent has output and merge into state file.
|
||||
# The agent writes these as it completes each required step so verify-complete.sh can gate recycling.
|
||||
EXISTING_CPS=$(echo "$agent" | jq -c '.checkpoints // []')
|
||||
NEW_CHECKPOINTS_JSON="$EXISTING_CPS"
|
||||
if [ -n "$RAW" ]; then
|
||||
FOUND_CPS=$(echo "$RAW" \
|
||||
| grep -oE "CHECKPOINT:[a-zA-Z0-9_-]+" \
|
||||
| sed 's/CHECKPOINT://' \
|
||||
| sort -u \
|
||||
| jq -R . | jq -s . 2>/dev/null || echo "[]")
|
||||
NEW_CHECKPOINTS_JSON=$(jq -n \
|
||||
--argjson existing "$EXISTING_CPS" \
|
||||
--argjson found "$FOUND_CPS" \
|
||||
'($existing + $found) | unique' 2>/dev/null || echo "$EXISTING_CPS")
|
||||
fi
|
||||
|
||||
# Compute content hash for stuck-detection (only for running agents)
|
||||
CURRENT_HASH=""
|
||||
if [[ "$PANE_STATE" == "running" ]] && [ -n "$RAW" ]; then
|
||||
CURRENT_HASH=$(echo "$RAW" | tail -20 | md5_hash)
|
||||
fi
|
||||
|
||||
NEW_STATE="$STATE"
|
||||
NEW_IDLE_SINCE="$IDLE_SINCE"
|
||||
NEW_REVISION_COUNT="$REVISION_COUNT"
|
||||
ACTION="none"
|
||||
REASON="$PANE_REASON"
|
||||
|
||||
case "$PANE_STATE" in
|
||||
complete)
|
||||
# Agent output ORCHESTRATOR:DONE — mark pending_evaluation so orchestrator handles it.
|
||||
# run-loop does NOT verify or notify; orchestrator's background poll picks this up.
|
||||
NEW_STATE="pending_evaluation"
|
||||
ACTION="complete" # run-loop logs it but takes no action
|
||||
;;
|
||||
waiting_approval)
|
||||
NEW_STATE="waiting_approval"
|
||||
ACTION="approve"
|
||||
;;
|
||||
idle)
|
||||
# Agent process has exited — needs restart
|
||||
NEW_STATE="idle"
|
||||
ACTION="kick"
|
||||
REASON="agent exited (shell is foreground)"
|
||||
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
|
||||
NEW_IDLE_SINCE=$NOW
|
||||
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
|
||||
NEW_STATE="escalated"
|
||||
ACTION="none"
|
||||
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
|
||||
fi
|
||||
;;
|
||||
running)
|
||||
# Clear idle_since only when transitioning from idle (agent was kicked and
|
||||
# restarted). Do NOT reset for stuck — idle_since must persist across polls
|
||||
# so STUCK_DURATION can accumulate and trigger escalation.
|
||||
# Also update the local IDLE_SINCE so the hash-stability check below uses
|
||||
# the reset value on this same poll, not the stale kick timestamp.
|
||||
if [[ "$STATE" == "idle" ]]; then
|
||||
NEW_IDLE_SINCE=0
|
||||
IDLE_SINCE=0
|
||||
fi
|
||||
# Check if hash has been stable (agent may be stuck mid-task)
|
||||
if [ -n "$CURRENT_HASH" ] && [ "$CURRENT_HASH" = "$LAST_HASH" ] && [ "$LAST_HASH" != "" ]; then
|
||||
if [ "$IDLE_SINCE" = "0" ] || [ "$IDLE_SINCE" = "null" ]; then
|
||||
NEW_IDLE_SINCE=$NOW
|
||||
else
|
||||
STUCK_DURATION=$(( NOW - IDLE_SINCE ))
|
||||
if [ "$STUCK_DURATION" -gt "$IDLE_THRESHOLD" ]; then
|
||||
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
|
||||
NEW_IDLE_SINCE=$NOW
|
||||
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
|
||||
NEW_STATE="escalated"
|
||||
ACTION="none"
|
||||
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
|
||||
else
|
||||
NEW_STATE="stuck"
|
||||
ACTION="kick"
|
||||
REASON="output unchanged for ${STUCK_DURATION}s (threshold: ${IDLE_THRESHOLD}s)"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
else
|
||||
# Only reset the idle timer when we have a valid hash comparison (pane
|
||||
# capture succeeded). If CURRENT_HASH is empty (tmux capture-pane failed),
|
||||
# preserve existing timers so stuck detection is not inadvertently reset.
|
||||
if [ -n "$CURRENT_HASH" ]; then
|
||||
NEW_STATE="running"
|
||||
NEW_IDLE_SINCE=0
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
error)
|
||||
REASON="classify error: $PANE_REASON"
|
||||
;;
|
||||
esac
|
||||
|
||||
# Build updated agent record (ensure idle_since and revision_count are numeric)
|
||||
# Use || true on each jq call so a malformed field skips this agent rather than
|
||||
# aborting the entire poll cycle under set -e.
|
||||
UPDATED_AGENT=$(echo "$agent" | jq \
|
||||
--arg state "$NEW_STATE" \
|
||||
--arg hash "$CURRENT_HASH" \
|
||||
--argjson now "$NOW" \
|
||||
--arg idle_since "$NEW_IDLE_SINCE" \
|
||||
--arg revision_count "$NEW_REVISION_COUNT" \
|
||||
--argjson checkpoints "$NEW_CHECKPOINTS_JSON" \
|
||||
'.state = $state
|
||||
| .last_output_hash = (if $hash == "" then .last_output_hash else $hash end)
|
||||
| .last_seen_at = $now
|
||||
| .idle_since = ($idle_since | tonumber)
|
||||
| .revision_count = ($revision_count | tonumber)
|
||||
| .checkpoints = $checkpoints' 2>/dev/null) || {
|
||||
echo "Warning: failed to build updated agent for window $WINDOW — keeping original" >&2
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
|
||||
continue
|
||||
}
|
||||
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$UPDATED_AGENT" '. + [$a]')
|
||||
|
||||
# Add action if needed
|
||||
if [ "$ACTION" != "none" ]; then
|
||||
ACTION_OBJ=$(jq -n \
|
||||
--arg window "$WINDOW" \
|
||||
--arg action "$ACTION" \
|
||||
--arg state "$NEW_STATE" \
|
||||
--arg worktree "$WORKTREE" \
|
||||
--arg objective "$OBJECTIVE" \
|
||||
--arg reason "$REASON" \
|
||||
'{window:$window, action:$action, state:$state, worktree:$worktree, objective:$objective, reason:$reason}')
|
||||
ACTIONS=$(echo "$ACTIONS" | jq --argjson a "$ACTION_OBJ" '. + [$a]')
|
||||
fi
|
||||
|
||||
done <<< "$AGENTS_JSON"
|
||||
|
||||
# Atomic state file update
|
||||
jq --argjson agents "$UPDATED_AGENTS" \
|
||||
--argjson now "$NOW" \
|
||||
'.agents = $agents | .last_poll_at = $now' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
|
||||
echo "$ACTIONS"
|
||||
32
.claude/skills/orchestrate/scripts/recycle-agent.sh
Executable file
32
.claude/skills/orchestrate/scripts/recycle-agent.sh
Executable file
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env bash
|
||||
# recycle-agent.sh — kill a tmux window and restore the worktree to its spare branch
|
||||
#
|
||||
# Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH
|
||||
# WINDOW — tmux target, e.g. autogpt1:3
|
||||
# WORKTREE_PATH — absolute path to the git worktree
|
||||
# SPARE_BRANCH — branch to restore, e.g. spare/6
|
||||
#
|
||||
# Stdout: one status line
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
WINDOW="$1"
|
||||
WORKTREE_PATH="$2"
|
||||
SPARE_BRANCH="$3"
|
||||
|
||||
# Kill the tmux window (ignore error — may already be gone)
|
||||
tmux kill-window -t "$WINDOW" 2>/dev/null || true
|
||||
|
||||
# Restore to spare branch: abort any in-progress operation, then clean
|
||||
git -C "$WORKTREE_PATH" rebase --abort 2>/dev/null || true
|
||||
git -C "$WORKTREE_PATH" merge --abort 2>/dev/null || true
|
||||
git -C "$WORKTREE_PATH" reset --hard HEAD 2>/dev/null
|
||||
git -C "$WORKTREE_PATH" clean -fd 2>/dev/null
|
||||
git -C "$WORKTREE_PATH" checkout "$SPARE_BRANCH"
|
||||
|
||||
echo "Recycled: $(basename "$WORKTREE_PATH") → $SPARE_BRANCH (window $WINDOW closed)"
|
||||
215
.claude/skills/orchestrate/scripts/run-loop.sh
Executable file
215
.claude/skills/orchestrate/scripts/run-loop.sh
Executable file
@@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env bash
|
||||
# run-loop.sh — Mechanical babysitter for the agent fleet (runs in its own tmux window)
|
||||
#
|
||||
# Handles ONLY two things that need no intelligence:
|
||||
# idle → restart claude using --resume SESSION_ID (or --continue) to restore context
|
||||
# approve → auto-approve safe dialogs, press Enter on numbered-option dialogs
|
||||
#
|
||||
# Everything else — ORCHESTRATOR:DONE, verification, /pr-test, final evaluation,
|
||||
# marking done, deciding to close windows — is the orchestrating Claude's job.
|
||||
# poll-cycle.sh sets state to pending_evaluation when ORCHESTRATOR:DONE is detected;
|
||||
# the orchestrator's background poll loop handles it from there.
|
||||
#
|
||||
# Usage: run-loop.sh
|
||||
# Env: POLL_INTERVAL (default: 30), ORCHESTRATOR_STATE_FILE
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Copy scripts to a stable location outside the repo so they survive branch
|
||||
# checkouts (e.g. recycle-agent.sh switching spare/N back into this worktree
|
||||
# would wipe .claude/skills/orchestrate/scripts if the skill only exists on the
|
||||
# current branch).
|
||||
_ORIGIN_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
STABLE_SCRIPTS_DIR="$HOME/.claude/orchestrator/scripts"
|
||||
mkdir -p "$STABLE_SCRIPTS_DIR"
|
||||
cp "$_ORIGIN_DIR"/*.sh "$STABLE_SCRIPTS_DIR/"
|
||||
chmod +x "$STABLE_SCRIPTS_DIR"/*.sh
|
||||
SCRIPTS_DIR="$STABLE_SCRIPTS_DIR"
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
# Adaptive polling: starts at base interval, backs off up to POLL_IDLE_MAX when
|
||||
# no agents need attention, resets on any activity or waiting_approval state.
|
||||
POLL_INTERVAL="${POLL_INTERVAL:-30}"
|
||||
POLL_IDLE_MAX=${POLL_IDLE_MAX:-300}
|
||||
POLL_CURRENT=$POLL_INTERVAL
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_state WINDOW FIELD VALUE
|
||||
# ---------------------------------------------------------------------------
|
||||
update_state() {
|
||||
local window="$1" field="$2" value="$3"
|
||||
jq --arg w "$window" --arg f "$field" --arg v "$value" \
|
||||
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
}
|
||||
|
||||
update_state_int() {
|
||||
local window="$1" field="$2" value="$3"
|
||||
jq --arg w "$window" --arg f "$field" --argjson v "$value" \
|
||||
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
}
|
||||
|
||||
agent_field() {
|
||||
jq -r --arg w "$1" --arg f "$2" \
|
||||
'.agents[] | select(.window == $w) | .[$f] // ""' \
|
||||
"$STATE_FILE" 2>/dev/null
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wait_for_prompt WINDOW — wait up to 60s for Claude's ❯ prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
wait_for_prompt() {
|
||||
local window="$1"
|
||||
for i in $(seq 1 60); do
|
||||
local cmd pane
|
||||
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
|
||||
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
|
||||
if echo "$pane" | grep -q "Enter to confirm"; then
|
||||
tmux send-keys -t "$window" Down Enter; sleep 2; continue
|
||||
fi
|
||||
[[ "$cmd" == "node" ]] && echo "$pane" | grep -q "❯" && return 0
|
||||
sleep 1
|
||||
done
|
||||
return 1 # timed out
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wait_for_claude_idle WINDOW — wait up to 30s for Claude to reach idle ❯ prompt
|
||||
# (no spinner or busy indicator visible in the last 3 lines of pane output)
|
||||
# Returns 0 when idle, 1 on timeout.
|
||||
# ---------------------------------------------------------------------------
|
||||
wait_for_claude_idle() {
|
||||
local window="$1"
|
||||
local timeout="${2:-30}"
|
||||
local elapsed=0
|
||||
while (( elapsed < timeout )); do
|
||||
local cmd pane pane_tail
|
||||
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
|
||||
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
|
||||
pane_tail=$(echo "$pane" | tail -3)
|
||||
# Check full pane (not just tail) — 'Enter to confirm' dialog can scroll above last 3 lines.
|
||||
# Do NOT reset elapsed — resetting allows an infinite loop if the dialog never clears.
|
||||
if echo "$pane" | grep -q "Enter to confirm"; then
|
||||
tmux send-keys -t "$window" Down Enter
|
||||
sleep 2; (( elapsed += 2 )); continue
|
||||
fi
|
||||
# Must be running under node (Claude is live)
|
||||
if [[ "$cmd" == "node" ]]; then
|
||||
# Idle: ❯ prompt visible AND no spinner/busy text in last 3 lines
|
||||
if echo "$pane_tail" | grep -q "❯" && \
|
||||
! echo "$pane_tail" | grep -qE '[✳✽✢✶·✻✼✿❋✤]|Running…|Compacting'; then
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
sleep 2
|
||||
(( elapsed += 2 ))
|
||||
done
|
||||
return 1 # timed out
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_kick WINDOW STATE — only for idle (crashed) agents, not stuck
|
||||
# ---------------------------------------------------------------------------
|
||||
handle_kick() {
|
||||
local window="$1" state="$2"
|
||||
[[ "$state" != "idle" ]] && return # stuck agents handled by supervisor
|
||||
|
||||
local worktree_path session_id
|
||||
worktree_path=$(agent_field "$window" "worktree_path")
|
||||
session_id=$(agent_field "$window" "session_id")
|
||||
|
||||
echo "[$(date +%H:%M:%S)] KICK restart $window — agent exited, resuming session"
|
||||
|
||||
# Wait for the shell prompt before typing — avoids sending into a still-draining pane
|
||||
wait_for_claude_idle "$window" 30 \
|
||||
|| echo "[$(date +%H:%M:%S)] KICK WARNING $window — pane still busy before resume, sending anyway"
|
||||
|
||||
# Resume the exact session so the agent retains full context — no need to re-send objective
|
||||
if [ -n "$session_id" ]; then
|
||||
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --resume '${session_id}' --permission-mode bypassPermissions" Enter
|
||||
else
|
||||
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --continue --permission-mode bypassPermissions" Enter
|
||||
fi
|
||||
|
||||
wait_for_prompt "$window" || echo "[$(date +%H:%M:%S)] KICK WARNING $window — timed out waiting for ❯"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_approve WINDOW — auto-approve dialogs that need no judgment
|
||||
# ---------------------------------------------------------------------------
|
||||
handle_approve() {
|
||||
local window="$1"
|
||||
local pane_tail
|
||||
pane_tail=$(tmux capture-pane -t "$window" -p 2>/dev/null | tail -3 || echo "")
|
||||
|
||||
# Settings error dialog at startup
|
||||
if echo "$pane_tail" | grep -q "Enter to confirm"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE dialog $window — settings error"
|
||||
tmux send-keys -t "$window" Down Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# Numbered-option dialog (e.g. "Do you want to make this edit?")
|
||||
# ❯ is already on option 1 (Yes) — Enter confirms it
|
||||
if echo "$pane_tail" | grep -qE "❯\s*1\." || echo "$pane_tail" | grep -q "Esc to cancel"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE edit $window"
|
||||
tmux send-keys -t "$window" "" Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# y/n prompt for safe operations
|
||||
if echo "$pane_tail" | grep -qiE "(^git |^npm |^pnpm |^poetry |^pytest|^docker |^make |^cargo |^pip |^yarn |curl .*(localhost|127\.0\.0\.1))"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE safe $window"
|
||||
tmux send-keys -t "$window" "y" Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# Anything else — supervisor handles it, just log
|
||||
echo "[$(date +%H:%M:%S)] APPROVE skip $window — unknown dialog, supervisor will handle"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main loop
|
||||
# ---------------------------------------------------------------------------
|
||||
echo "[$(date +%H:%M:%S)] run-loop started (mechanical only, poll ${POLL_INTERVAL}s→${POLL_IDLE_MAX}s adaptive)"
|
||||
echo "[$(date +%H:%M:%S)] Supervisor: orchestrating Claude session (not a separate window)"
|
||||
echo "---"
|
||||
|
||||
while true; do
|
||||
if ! jq -e '.active == true' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
echo "[$(date +%H:%M:%S)] active=false — exiting."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
ACTIONS=$("$SCRIPTS_DIR/poll-cycle.sh" 2>/dev/null || echo "[]")
|
||||
KICKED=0; DONE=0
|
||||
|
||||
while IFS= read -r action; do
|
||||
[ -z "$action" ] && continue
|
||||
WINDOW=$(echo "$action" | jq -r '.window // ""')
|
||||
ACTION=$(echo "$action" | jq -r '.action // ""')
|
||||
STATE=$(echo "$action" | jq -r '.state // ""')
|
||||
|
||||
case "$ACTION" in
|
||||
kick) handle_kick "$WINDOW" "$STATE" || true; KICKED=$(( KICKED + 1 )) ;;
|
||||
approve) handle_approve "$WINDOW" || true ;;
|
||||
complete) DONE=$(( DONE + 1 )) ;; # poll-cycle already set state=pending_evaluation; orchestrator handles
|
||||
esac
|
||||
done < <(echo "$ACTIONS" | jq -c '.[]' 2>/dev/null || true)
|
||||
|
||||
RUNNING=$(jq '[.agents[] | select(.state | test("running|stuck|waiting_approval|idle"))] | length' \
|
||||
"$STATE_FILE" 2>/dev/null || echo 0)
|
||||
|
||||
# Adaptive backoff: reset to base on activity or waiting_approval agents; back off when truly idle
|
||||
WAITING=$(jq '[.agents[] | select(.state == "waiting_approval")] | length' "$STATE_FILE" 2>/dev/null || echo 0)
|
||||
if (( KICKED > 0 || DONE > 0 || WAITING > 0 )); then
|
||||
POLL_CURRENT=$POLL_INTERVAL
|
||||
else
|
||||
POLL_CURRENT=$(( POLL_CURRENT + POLL_CURRENT / 2 + 1 ))
|
||||
(( POLL_CURRENT > POLL_IDLE_MAX )) && POLL_CURRENT=$POLL_IDLE_MAX
|
||||
fi
|
||||
|
||||
echo "[$(date +%H:%M:%S)] Poll — ${RUNNING} running ${KICKED} kicked ${DONE} recycled (next in ${POLL_CURRENT}s)"
|
||||
sleep "$POLL_CURRENT"
|
||||
done
|
||||
129
.claude/skills/orchestrate/scripts/spawn-agent.sh
Executable file
129
.claude/skills/orchestrate/scripts/spawn-agent.sh
Executable file
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env bash
|
||||
# spawn-agent.sh — create tmux window, checkout branch, launch claude, send task
|
||||
#
|
||||
# Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]
|
||||
# SESSION — tmux session name, e.g. autogpt1
|
||||
# WORKTREE_PATH — absolute path to the git worktree
|
||||
# SPARE_BRANCH — spare branch being replaced, e.g. spare/6 (saved for recycle)
|
||||
# NEW_BRANCH — task branch to create, e.g. feat/my-feature
|
||||
# OBJECTIVE — task description sent to the agent
|
||||
# PR_NUMBER — (optional) GitHub PR number for completion verification
|
||||
# STEPS... — (optional) required checkpoint names, e.g. pr-address pr-test
|
||||
#
|
||||
# Stdout: SESSION:WINDOW_INDEX (nothing else — callers rely on this)
|
||||
# Exit non-zero on failure.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [ $# -lt 5 ]; then
|
||||
echo "Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
SESSION="$1"
|
||||
WORKTREE_PATH="$2"
|
||||
SPARE_BRANCH="$3"
|
||||
NEW_BRANCH="$4"
|
||||
OBJECTIVE="$5"
|
||||
PR_NUMBER="${6:-}"
|
||||
STEPS=("${@:7}")
|
||||
WORKTREE_NAME=$(basename "$WORKTREE_PATH")
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
# Generate a stable session ID so this agent's Claude session can always be resumed:
|
||||
# claude --resume $SESSION_ID --permission-mode bypassPermissions
|
||||
SESSION_ID=$(uuidgen 2>/dev/null || python3 -c "import uuid; print(uuid.uuid4())")
|
||||
|
||||
# Create (or switch to) the task branch
|
||||
git -C "$WORKTREE_PATH" checkout -b "$NEW_BRANCH" 2>/dev/null \
|
||||
|| git -C "$WORKTREE_PATH" checkout "$NEW_BRANCH"
|
||||
|
||||
# Open a new named tmux window; capture its numeric index
|
||||
WIN_IDX=$(tmux new-window -t "$SESSION" -n "$WORKTREE_NAME" -P -F '#{window_index}')
|
||||
WINDOW="${SESSION}:${WIN_IDX}"
|
||||
|
||||
# Append the initial agent record to the state file so subsequent jq updates find it.
|
||||
# This must happen before the pr_number/steps update below.
|
||||
if [ -f "$STATE_FILE" ]; then
|
||||
NOW=$(date +%s)
|
||||
jq --arg window "$WINDOW" \
|
||||
--arg worktree "$WORKTREE_NAME" \
|
||||
--arg worktree_path "$WORKTREE_PATH" \
|
||||
--arg spare_branch "$SPARE_BRANCH" \
|
||||
--arg branch "$NEW_BRANCH" \
|
||||
--arg objective "$OBJECTIVE" \
|
||||
--arg session_id "$SESSION_ID" \
|
||||
--argjson now "$NOW" \
|
||||
'.agents += [{
|
||||
"window": $window,
|
||||
"worktree": $worktree,
|
||||
"worktree_path": $worktree_path,
|
||||
"spare_branch": $spare_branch,
|
||||
"branch": $branch,
|
||||
"objective": $objective,
|
||||
"session_id": $session_id,
|
||||
"state": "running",
|
||||
"checkpoints": [],
|
||||
"last_output_hash": "",
|
||||
"last_seen_at": $now,
|
||||
"spawned_at": $now,
|
||||
"idle_since": 0,
|
||||
"revision_count": 0,
|
||||
"last_rebriefed_at": 0
|
||||
}]' "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
fi
|
||||
|
||||
# Store pr_number + steps in state file if provided (enables verify-complete.sh).
|
||||
# The agent record was appended above so the jq select now finds it.
|
||||
if [ -n "$PR_NUMBER" ] && [ -f "$STATE_FILE" ]; then
|
||||
if [ "${#STEPS[@]}" -gt 0 ]; then
|
||||
STEPS_JSON=$(printf '%s\n' "${STEPS[@]}" | jq -R . | jq -s .)
|
||||
else
|
||||
STEPS_JSON='[]'
|
||||
fi
|
||||
jq --arg w "$WINDOW" --arg pr "$PR_NUMBER" --argjson steps "$STEPS_JSON" \
|
||||
'.agents |= map(if .window == $w then . + {pr_number: $pr, steps: $steps, checkpoints: []} else . end)' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
fi
|
||||
|
||||
# Launch claude with a stable session ID so it can always be resumed after a crash:
|
||||
# claude --resume SESSION_ID --permission-mode bypassPermissions
|
||||
tmux send-keys -t "$WINDOW" "cd '${WORKTREE_PATH}' && claude --permission-mode bypassPermissions --session-id '${SESSION_ID}'" Enter
|
||||
|
||||
# wait_for_claude_idle — poll until the pane shows idle ❯ with no spinner in the last 3 lines.
|
||||
# Returns 0 when idle, 1 on timeout.
|
||||
_wait_idle() {
|
||||
local window="$1" timeout="${2:-60}" elapsed=0
|
||||
while (( elapsed < timeout )); do
|
||||
local cmd pane_tail
|
||||
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
|
||||
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
|
||||
pane_tail=$(echo "$pane" | tail -3)
|
||||
# Check full pane (not just tail) — 'Enter to confirm' dialog can appear above the last 3 lines
|
||||
if echo "$pane" | grep -q "Enter to confirm"; then
|
||||
tmux send-keys -t "$window" Down Enter
|
||||
sleep 2; (( elapsed += 2 )); continue
|
||||
fi
|
||||
if [[ "$cmd" == "node" ]] && \
|
||||
echo "$pane_tail" | grep -q "❯" && \
|
||||
! echo "$pane_tail" | grep -qE '[✳✽✢✶·✻✼✿❋✤]|Running…|Compacting'; then
|
||||
return 0
|
||||
fi
|
||||
sleep 2; (( elapsed += 2 ))
|
||||
done
|
||||
return 1
|
||||
}
|
||||
|
||||
# Wait up to 60s for claude to be fully interactive and idle (❯ visible, no spinner).
|
||||
if ! _wait_idle "$WINDOW" 60; then
|
||||
echo "[spawn-agent] WARNING: timed out waiting for idle ❯ prompt on $WINDOW — sending objective anyway" >&2
|
||||
fi
|
||||
|
||||
# Send the task. Split text and Enter — if combined, Enter can fire before the string
|
||||
# is fully buffered, leaving the message stuck as "[Pasted text +N lines]" unsent.
|
||||
tmux send-keys -t "$WINDOW" "${OBJECTIVE} Output each completed step as CHECKPOINT:<step-name>. When ALL steps are done, output ORCHESTRATOR:DONE on its own line."
|
||||
sleep 0.3
|
||||
tmux send-keys -t "$WINDOW" Enter
|
||||
|
||||
# Only output the window address — nothing else (callers parse this)
|
||||
echo "$WINDOW"
|
||||
43
.claude/skills/orchestrate/scripts/status.sh
Executable file
43
.claude/skills/orchestrate/scripts/status.sh
Executable file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env bash
|
||||
# status.sh — print orchestrator status: state file summary + live tmux pane commands
|
||||
#
|
||||
# Usage: status.sh
|
||||
# Reads: ~/.claude/orchestrator-state.json
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
if [ ! -f "$STATE_FILE" ] || ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
echo "No orchestrator state found at $STATE_FILE"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Header: active status, session, thresholds, last poll
|
||||
jq -r '
|
||||
"=== Orchestrator [\(if .active then "RUNNING" else "STOPPED" end)] ===",
|
||||
"Session: \(.tmux_session // "unknown") | Idle threshold: \(.idle_threshold_seconds // 300)s",
|
||||
"Last poll: \(if (.last_poll_at // 0) == 0 then "never" else (.last_poll_at | strftime("%H:%M:%S")) end)",
|
||||
""
|
||||
' "$STATE_FILE"
|
||||
|
||||
# Each agent: state, window, worktree/branch, truncated objective
|
||||
AGENT_COUNT=$(jq '.agents | length' "$STATE_FILE")
|
||||
if [ "$AGENT_COUNT" -eq 0 ]; then
|
||||
echo " (no agents registered)"
|
||||
else
|
||||
jq -r '
|
||||
.agents[] |
|
||||
" [\(.state | ascii_upcase)] \(.window) \(.worktree)/\(.branch)",
|
||||
" \(.objective // "" | .[0:70])"
|
||||
' "$STATE_FILE"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
|
||||
# Live pane_current_command for non-done agents
|
||||
while IFS= read -r WINDOW; do
|
||||
[ -z "$WINDOW" ] && continue
|
||||
CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "unreachable")
|
||||
echo " $WINDOW live: $CMD"
|
||||
done < <(jq -r '.agents[] | select(.state != "done") | .window' "$STATE_FILE" 2>/dev/null || true)
|
||||
180
.claude/skills/orchestrate/scripts/verify-complete.sh
Normal file
180
.claude/skills/orchestrate/scripts/verify-complete.sh
Normal file
@@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env bash
|
||||
# verify-complete.sh — verify a PR task is truly done before marking the agent done
|
||||
#
|
||||
# Check order matters:
|
||||
# 1. Checkpoints — did the agent do all required steps?
|
||||
# 2. CI complete — no pending (bots post comments AFTER their check runs, must wait)
|
||||
# 3. CI passing — no failures (agent must fix before done)
|
||||
# 4. spawned_at — a new CI run was triggered after agent spawned (proves real work)
|
||||
# 5. Unresolved threads — checked AFTER CI so bot-posted comments are included
|
||||
# 6. CHANGES_REQUESTED — checked AFTER CI so bot reviews are included
|
||||
#
|
||||
# Usage: verify-complete.sh WINDOW
|
||||
# Exit 0 = verified complete; exit 1 = not complete (stderr has reason)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
WINDOW="$1"
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
PR_NUMBER=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .pr_number // ""' "$STATE_FILE" 2>/dev/null)
|
||||
STEPS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .steps // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
|
||||
CHECKPOINTS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .checkpoints // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
|
||||
WORKTREE_PATH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .worktree_path // ""' "$STATE_FILE" 2>/dev/null)
|
||||
BRANCH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .branch // ""' "$STATE_FILE" 2>/dev/null)
|
||||
SPAWNED_AT=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .spawned_at // "0"' "$STATE_FILE" 2>/dev/null || echo "0")
|
||||
|
||||
# No PR number = cannot verify
|
||||
if [ -z "$PR_NUMBER" ]; then
|
||||
echo "NOT COMPLETE: no pr_number in state — set pr_number or mark done manually" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 1: all required steps are checkpointed ---
|
||||
MISSING=""
|
||||
while IFS= read -r step; do
|
||||
[ -z "$step" ] && continue
|
||||
if ! echo "$CHECKPOINTS" | grep -qFx "$step"; then
|
||||
MISSING="$MISSING $step"
|
||||
fi
|
||||
done <<< "$STEPS"
|
||||
|
||||
if [ -n "$MISSING" ]; then
|
||||
echo "NOT COMPLETE: missing checkpoints:$MISSING on PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Resolve repo for all GitHub checks below
|
||||
REPO=$(jq -r '.repo // ""' "$STATE_FILE" 2>/dev/null || echo "")
|
||||
if [ -z "$REPO" ] && [ -n "$WORKTREE_PATH" ] && [ -d "$WORKTREE_PATH" ]; then
|
||||
REPO=$(git -C "$WORKTREE_PATH" remote get-url origin 2>/dev/null \
|
||||
| sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
|
||||
fi
|
||||
|
||||
if [ -z "$REPO" ]; then
|
||||
echo "Warning: cannot resolve repo — skipping CI/thread checks" >&2
|
||||
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓ (CI/thread checks skipped — no repo)"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
CI_BUCKETS=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket 2>/dev/null || echo "[]")
|
||||
|
||||
# --- Check 2: CI fully complete — no pending checks ---
|
||||
# Pending checks MUST finish before we check threads/reviews:
|
||||
# bots (Seer, Check PR Status, etc.) post comments and CHANGES_REQUESTED AFTER their CI check runs.
|
||||
PENDING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "pending")] | length' 2>/dev/null || echo "0")
|
||||
if [ "$PENDING" -gt 0 ]; then
|
||||
PENDING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
|
||||
| jq -r '[.[] | select(.bucket == "pending") | .name] | join(", ")' 2>/dev/null || echo "unknown")
|
||||
echo "NOT COMPLETE: $PENDING CI checks still pending on PR #$PR_NUMBER ($PENDING_NAMES)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 3: CI passing — no failures ---
|
||||
FAILING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "fail")] | length' 2>/dev/null || echo "0")
|
||||
if [ "$FAILING" -gt 0 ]; then
|
||||
FAILING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
|
||||
| jq -r '[.[] | select(.bucket == "fail") | .name] | join(", ")' 2>/dev/null || echo "unknown")
|
||||
echo "NOT COMPLETE: $FAILING failing CI checks on PR #$PR_NUMBER ($FAILING_NAMES)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 4: a new CI run was triggered AFTER the agent spawned ---
|
||||
if [ -n "$BRANCH" ] && [ "${SPAWNED_AT:-0}" -gt 0 ]; then
|
||||
LATEST_RUN_AT=$(gh run list --repo "$REPO" --branch "$BRANCH" \
|
||||
--json createdAt --limit 1 2>/dev/null | jq -r '.[0].createdAt // ""')
|
||||
if [ -n "$LATEST_RUN_AT" ]; then
|
||||
if date --version >/dev/null 2>&1; then
|
||||
LATEST_RUN_EPOCH=$(date -d "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
|
||||
else
|
||||
LATEST_RUN_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
|
||||
fi
|
||||
if [ "$LATEST_RUN_EPOCH" -le "$SPAWNED_AT" ]; then
|
||||
echo "NOT COMPLETE: latest CI run on $BRANCH predates agent spawn — agent may not have pushed yet" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
OWNER=$(echo "$REPO" | cut -d/ -f1)
|
||||
REPONAME=$(echo "$REPO" | cut -d/ -f2)
|
||||
|
||||
# --- Check 5: no unresolved review threads (checked AFTER CI — bots post after their check) ---
|
||||
UNRESOLVED=$(gh api graphql -f query="
|
||||
{ repository(owner: \"${OWNER}\", name: \"${REPONAME}\") {
|
||||
pullRequest(number: ${PR_NUMBER}) {
|
||||
reviewThreads(first: 50) { nodes { isResolved } }
|
||||
}
|
||||
}
|
||||
}
|
||||
" --jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)] | length' 2>/dev/null || echo "0")
|
||||
|
||||
if [ "$UNRESOLVED" -gt 0 ]; then
|
||||
echo "NOT COMPLETE: $UNRESOLVED unresolved review threads on PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 6: no CHANGES_REQUESTED (checked AFTER CI — bots post reviews after their check) ---
|
||||
# A CHANGES_REQUESTED review is stale if the latest commit was pushed AFTER the review was submitted.
|
||||
# Stale reviews (pre-dating the fixing commits) should not block verification.
|
||||
#
|
||||
# Fetch commits and latestReviews in a single call and fail closed — if gh fails,
|
||||
# treat that as NOT COMPLETE rather than silently passing.
|
||||
# Use latestReviews (not reviews) so each reviewer's latest state is used — superseded
|
||||
# CHANGES_REQUESTED entries are automatically excluded when the reviewer later approved.
|
||||
# Note: we intentionally use committedDate (not PR updatedAt) because updatedAt changes on any
|
||||
# PR activity (bot comments, label changes) which would create false negatives.
|
||||
PR_REVIEW_METADATA=$(gh pr view "$PR_NUMBER" --repo "$REPO" \
|
||||
--json commits,latestReviews 2>/dev/null) || {
|
||||
echo "NOT COMPLETE: unable to fetch PR review metadata for PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
LATEST_COMMIT_DATE=$(jq -r '.commits[-1].committedDate // ""' <<< "$PR_REVIEW_METADATA")
|
||||
CHANGES_REQUESTED_REVIEWS=$(jq '[.latestReviews[]? | select(.state == "CHANGES_REQUESTED")]' <<< "$PR_REVIEW_METADATA")
|
||||
|
||||
BLOCKING_CHANGES_REQUESTED=0
|
||||
BLOCKING_REQUESTERS=""
|
||||
|
||||
if [ -n "$LATEST_COMMIT_DATE" ] && [ "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length)" -gt 0 ]; then
|
||||
if date --version >/dev/null 2>&1; then
|
||||
LATEST_COMMIT_EPOCH=$(date -d "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
else
|
||||
LATEST_COMMIT_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
fi
|
||||
|
||||
while IFS= read -r review; do
|
||||
[ -z "$review" ] && continue
|
||||
REVIEW_DATE=$(echo "$review" | jq -r '.submittedAt // ""')
|
||||
REVIEWER=$(echo "$review" | jq -r '.author.login // "unknown"')
|
||||
if [ -z "$REVIEW_DATE" ]; then
|
||||
# No submission date — treat as fresh (conservative: blocks verification)
|
||||
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
|
||||
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
|
||||
else
|
||||
if date --version >/dev/null 2>&1; then
|
||||
REVIEW_EPOCH=$(date -d "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
else
|
||||
REVIEW_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
fi
|
||||
if [ "$REVIEW_EPOCH" -gt "$LATEST_COMMIT_EPOCH" ]; then
|
||||
# Review was submitted AFTER latest commit — still fresh, blocks verification
|
||||
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
|
||||
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
|
||||
fi
|
||||
# Review submitted BEFORE latest commit — stale, skip
|
||||
fi
|
||||
done <<< "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -c '.[]')"
|
||||
else
|
||||
# No commit date or no changes_requested — check raw count as fallback
|
||||
BLOCKING_CHANGES_REQUESTED=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length 2>/dev/null || echo "0")
|
||||
BLOCKING_REQUESTERS=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -r '[.[].author.login] | join(", ")' 2>/dev/null || echo "unknown")
|
||||
fi
|
||||
|
||||
if [ "$BLOCKING_CHANGES_REQUESTED" -gt 0 ]; then
|
||||
echo "NOT COMPLETE: CHANGES_REQUESTED (after latest commit) from ${BLOCKING_REQUESTERS} on PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓, CI complete + green, 0 unresolved threads, no CHANGES_REQUESTED"
|
||||
exit 0
|
||||
@@ -29,30 +29,83 @@ gh pr view {N} --json body --jq '.body'
|
||||
|
||||
### 1. Inline review threads — GraphQL (primary source of actionable items)
|
||||
|
||||
Use GraphQL to fetch inline threads. It natively exposes `isResolved`, returns threads already grouped with all replies, and paginates via cursor — no manual thread reconstruction needed.
|
||||
> ⚠️ **WARNING — PAGINATE ALL PAGES BEFORE ADDRESSING ANYTHING**
|
||||
>
|
||||
> `reviewThreads(first: 100)` returns at most 100 threads per page AND returns threads **oldest-first**. On a PR with many review cycles (e.g. 373 threads), the oldest 100–200 threads are from past cycles and are **all already resolved**. Filtering client-side with `select(.isResolved == false)` on page 1 therefore yields **0 results** — even though pages 2–4 contain many unresolved threads from recent review cycles.
|
||||
>
|
||||
> **This is the most common failure mode:** agent fetches page 1, sees 0 unresolved after filtering, stops pagination, reports "done" — while hundreds of unresolved threads sit on later pages.
|
||||
>
|
||||
> One observed PR had 142 total threads: page 1 returned 0 unresolved (all old/resolved), while pages 2–3 had 111 unresolved. Another with 373 threads across 4 pages also had page 1 entirely resolved.
|
||||
>
|
||||
> **The rule: ALWAYS paginate to `hasNextPage == false` regardless of the per-page unresolved count. Never stop early because a page returns 0 unresolved.**
|
||||
|
||||
**Step 1 — Fetch total count and sanity-check the newest threads:**
|
||||
|
||||
```bash
|
||||
# Get total count and the newest 100 threads (last: 100 returns newest-first)
|
||||
gh api graphql -f query='
|
||||
{
|
||||
repository(owner: "Significant-Gravitas", name: "AutoGPT") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads(first: 100) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes {
|
||||
id
|
||||
isResolved
|
||||
path
|
||||
comments(last: 1) {
|
||||
nodes { databaseId body author { login } createdAt }
|
||||
reviewThreads { totalCount }
|
||||
newest: reviewThreads(last: 100) {
|
||||
nodes { isResolved }
|
||||
}
|
||||
}
|
||||
}
|
||||
}' | jq '{ total: .data.repository.pullRequest.reviewThreads.totalCount, newest_unresolved: [.data.repository.pullRequest.newest.nodes[] | select(.isResolved == false)] | length }'
|
||||
```
|
||||
|
||||
If `total > 100`, you have multiple pages — you **must** paginate all of them regardless of what `newest_unresolved` shows. The `last: 100` check is a sanity signal only; the full loop below is mandatory.
|
||||
|
||||
**Step 2 — Collect all unresolved thread IDs across all pages:**
|
||||
|
||||
```bash
|
||||
# Accumulate all unresolved threads — loop until hasNextPage == false
|
||||
CURSOR=""
|
||||
ALL_THREADS="[]"
|
||||
while true; do
|
||||
AFTER=${CURSOR:+", after: \"$CURSOR\""}
|
||||
PAGE=$(gh api graphql -f query="
|
||||
{
|
||||
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads(first: 100${AFTER}) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes {
|
||||
id
|
||||
isResolved
|
||||
path
|
||||
line
|
||||
comments(last: 1) {
|
||||
nodes { databaseId body author { login } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
}")
|
||||
# Append unresolved nodes from this page
|
||||
PAGE_THREADS=$(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)]')
|
||||
ALL_THREADS=$(echo "$ALL_THREADS $PAGE_THREADS" | jq -s 'add')
|
||||
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
|
||||
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
|
||||
[ "$HAS_NEXT" = "false" ] && break
|
||||
done
|
||||
|
||||
# Reverse so newest threads (last pages) are addressed first — GitHub returns oldest-first
|
||||
# and the most recent review cycle's comments are the ones blocking approval.
|
||||
ALL_THREADS=$(echo "$ALL_THREADS" | jq 'reverse')
|
||||
|
||||
echo "Total unresolved threads: $(echo "$ALL_THREADS" | jq 'length')"
|
||||
echo "$ALL_THREADS" | jq '[.[] | {id, path, line, body: .comments.nodes[0].body[:200]}]'
|
||||
```
|
||||
|
||||
If `pageInfo.hasNextPage` is true, fetch subsequent pages by adding `after: "<endCursor>"` to `reviewThreads(first: 100, after: "...")` and repeat until `hasNextPage` is false.
|
||||
**Step 3 — Address every thread in `ALL_THREADS`, then resolve.**
|
||||
|
||||
Only after this loop completes (all pages fetched, count confirmed) should you begin making fixes.
|
||||
|
||||
> **Why reverse?** GraphQL returns threads oldest-first and exposes no `orderBy` option. A PR with 373 threads has ~4 pages; threads from the latest review cycle land on the last pages. Processing in reverse ensures the newest, most blocking comments are addressed first — the earlier pages mostly contain outdated threads from prior cycles.
|
||||
|
||||
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
|
||||
|
||||
@@ -84,16 +137,65 @@ Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`gi
|
||||
|
||||
## For each unaddressed comment
|
||||
|
||||
Address comments **one at a time**: fix → commit → push → inline reply → next.
|
||||
**CRITICAL: The only valid sequence is fix → commit → push → reply → resolve. Never resolve a thread without a real code commit.**
|
||||
|
||||
Resolving a thread via `resolveReviewThread` without an actual fix is the most common failure mode — it makes unresolved counts drop without any real change, producing a false "done" signal. If the issue was genuinely a false positive (no code change needed), reply explaining why and then resolve. Otherwise:
|
||||
|
||||
Address comments **one at a time**: fix → commit → push → inline reply → resolve.
|
||||
|
||||
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
|
||||
2. Commit and push the fix
|
||||
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
|
||||
|
||||
Use a **markdown commit link** so GitHub renders it as a clickable reference. Always get the full SHA with `git rev-parse HEAD` **after** committing — never copy a SHA from a previous commit or hardcode one:
|
||||
|
||||
```bash
|
||||
FULL_SHA=$(git rev-parse HEAD)
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies \
|
||||
-f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): <description>"
|
||||
```
|
||||
|
||||
| Comment type | How to reply |
|
||||
|---|---|
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
|
||||
|
||||
### What counts as a valid resolution
|
||||
|
||||
Only two situations justify calling `resolveReviewThread`:
|
||||
|
||||
1. **Real code fix**: you changed the code, committed + pushed, and replied with the SHA. The commit diff must actually address the concern — not just touch the same file.
|
||||
2. **Genuine false positive**: the reviewer's concern does not apply to this code, and you can give a specific technical reason (e.g. "Not applicable — `sdk_cwd` is pre-validated by `_make_sdk_cwd()` which applies normpath + prefix assertion before reaching this point").
|
||||
|
||||
**Anti-patterns that look resolved but aren't — never do these:**
|
||||
- `"Accepted, tracked as follow-up"` — a deferral, not a fix. The concern is still open. Do not resolve.
|
||||
- `"Acknowledged"` or `"Same as above"` — these are acknowledgements, not fixes. Do not resolve.
|
||||
- `"Fixed in abc1234"` where `abc1234` is a commit that doesn't actually change the flagged line/logic — dishonest. Verify `git show abc1234 -- path/to/file` changes the right thing before posting.
|
||||
- Resolving without replying — the reviewer never sees what happened.
|
||||
|
||||
When in doubt: if a code change is needed, make it. A deferred issue means the thread stays open until the follow-up PR is merged.
|
||||
|
||||
## Codecov coverage
|
||||
|
||||
Codecov patch target is **80%** on changed lines. Checks are **informational** (not blocking) but should be green.
|
||||
|
||||
### Running coverage locally
|
||||
|
||||
**Backend** (from `autogpt_platform/backend/`):
|
||||
```bash
|
||||
poetry run pytest -s -vv --cov=backend --cov-branch --cov-report term-missing
|
||||
```
|
||||
|
||||
**Frontend** (from `autogpt_platform/frontend/`):
|
||||
```bash
|
||||
pnpm vitest run --coverage
|
||||
```
|
||||
|
||||
### When codecov/patch fails
|
||||
|
||||
1. Find uncovered files: `git diff --name-only $(gh pr view --json baseRefName --jq '.baseRefName')...HEAD`
|
||||
2. For each uncovered file — extract inline logic to `helpers.ts`/`helpers.py` and test those (highest ROI). Colocate tests as `*_test.py` (backend) or `__tests__/*.test.ts` (frontend).
|
||||
3. Run coverage locally to verify, commit, push.
|
||||
|
||||
## Format and commit
|
||||
|
||||
@@ -119,6 +221,22 @@ Then commit and **push immediately** — never batch commits without pushing. Ea
|
||||
|
||||
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
## Coverage
|
||||
|
||||
Codecov enforces patch coverage on new/changed lines — new code you write must be tested. Before pushing, verify you haven't left new lines uncovered:
|
||||
|
||||
```bash
|
||||
cd autogpt_platform/backend
|
||||
poetry run pytest --cov=. --cov-report=term-missing {path/to/changed/module}
|
||||
```
|
||||
|
||||
Look for lines marked `miss` — those are uncovered. Add tests for any new code you wrote as part of addressing comments.
|
||||
|
||||
**Rules:**
|
||||
- New code you add should have tests
|
||||
- Don't remove existing tests when fixing comments
|
||||
- If a reviewer asks you to delete code, also delete its tests, but verify coverage hasn't dropped on remaining lines
|
||||
|
||||
## The loop
|
||||
|
||||
```text
|
||||
@@ -208,3 +326,113 @@ git push
|
||||
```
|
||||
|
||||
5. Restart the polling loop from the top — new commits reset CI status.
|
||||
|
||||
## GitHub abuse rate limits
|
||||
|
||||
Two distinct rate limits exist — they have different causes and recovery times:
|
||||
|
||||
| Error | HTTP code | Cause | Recovery |
|
||||
|---|---|---|---|
|
||||
| `{"code":"abuse"}` | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **2–3 minutes**. 60s is often not enough. |
|
||||
| `{"message":"API rate limit exceeded"}` | 429 | Primary rate limit — too many API calls per hour | Wait until `X-RateLimit-Reset` header timestamp |
|
||||
|
||||
**Prevention:** Add `sleep 3` between individual thread reply API calls. When posting >20 replies, increase to `sleep 5`.
|
||||
|
||||
**Recovery from secondary rate limit (403):**
|
||||
1. Stop all API writes immediately
|
||||
2. Wait **2 minutes minimum** (not 60s — secondary limits are stricter)
|
||||
3. Resume with `sleep 3` between each call
|
||||
4. If 403 persists after 2 min, wait another 2 min before retrying
|
||||
|
||||
Never batch all replies in a tight loop — always space them out.
|
||||
|
||||
## Parallel thread resolution
|
||||
|
||||
When a PR has more than 10 unresolved threads, addressing one commit per thread is slow. Use this strategy instead:
|
||||
|
||||
### Group by file, batch per commit
|
||||
|
||||
1. Sort `ALL_THREADS` by `path` — threads in the same file can share a single commit.
|
||||
2. Fix all threads in one file → `git commit` → `git push` → reply to **all** those threads with the same SHA → resolve them all.
|
||||
3. Move to the next file group and repeat.
|
||||
|
||||
This reduces N commits to (number of files touched), which is usually 3–5 instead of 15–30.
|
||||
|
||||
### Posting replies concurrently (for large batches)
|
||||
|
||||
For truly independent thread groups (different files, no shared logic), you can post replies in parallel using background subshells — but always space out API writes:
|
||||
|
||||
```bash
|
||||
# Post replies to a batch of threads concurrently, 3s apart
|
||||
(
|
||||
sleep 3
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID1}/replies \
|
||||
-f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): ..."
|
||||
) &
|
||||
(
|
||||
sleep 6
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID2}/replies \
|
||||
-f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): ..."
|
||||
) &
|
||||
wait # wait for all background replies before resolving
|
||||
```
|
||||
|
||||
Then resolve sequentially (GraphQL mutations):
|
||||
```bash
|
||||
for THREAD_ID in "$THREAD1" "$THREAD2" "$THREAD3"; do
|
||||
gh api graphql -f query="mutation { resolveReviewThread(input: {threadId: \"${THREAD_ID}\"}) { thread { isResolved } } }"
|
||||
sleep 3
|
||||
done
|
||||
```
|
||||
|
||||
**Always sleep 3s between individual API writes** — GitHub's secondary rate limit (403) triggers on bursts of >20 writes. Increase to `sleep 5` when posting more than 20 replies in a batch.
|
||||
|
||||
## Resolving threads via GraphQL
|
||||
|
||||
Use `resolveReviewThread` **only after** the commit is pushed and the reply is posted:
|
||||
|
||||
```bash
|
||||
gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "THREAD_ID"}) { thread { isResolved } } }'
|
||||
```
|
||||
|
||||
**Never call this mutation before committing the fix.** The orchestrator will verify actual unresolved counts via GraphQL after you output `ORCHESTRATOR:DONE` — false resolutions will be caught and you will be re-briefed.
|
||||
|
||||
### Verify actual count before outputting ORCHESTRATOR:DONE
|
||||
|
||||
Before claiming "0 unresolved threads", always query GitHub directly — don't rely on your own bookkeeping. Paginate all pages — a single `first: 100` query misses threads beyond page 1:
|
||||
|
||||
```bash
|
||||
# Step 1: get total thread count
|
||||
gh api graphql -f query='
|
||||
{
|
||||
repository(owner: "Significant-Gravitas", name: "AutoGPT") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads { totalCount }
|
||||
}
|
||||
}
|
||||
}' | jq '.data.repository.pullRequest.reviewThreads.totalCount'
|
||||
|
||||
# Step 2: paginate all pages, count truly unresolved
|
||||
CURSOR=""; UNRESOLVED=0
|
||||
while true; do
|
||||
AFTER=${CURSOR:+", after: \"$CURSOR\""}
|
||||
PAGE=$(gh api graphql -f query="
|
||||
{
|
||||
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads(first: 100${AFTER}) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes { isResolved }
|
||||
}
|
||||
}
|
||||
}
|
||||
}")
|
||||
UNRESOLVED=$(( UNRESOLVED + $(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved==false)] | length') ))
|
||||
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
|
||||
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
|
||||
[ "$HAS_NEXT" = "false" ] && break
|
||||
done
|
||||
echo "Unresolved threads: $UNRESOLVED"
|
||||
```
|
||||
|
||||
Only output `ORCHESTRATOR:DONE` after this loop reports 0.
|
||||
|
||||
@@ -310,6 +310,28 @@ TOKEN=$(curl -s -X POST 'http://localhost:8000/auth/v1/token?grant_type=password
|
||||
curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
|
||||
```
|
||||
|
||||
### 3i. Disable onboarding for test user
|
||||
|
||||
The frontend redirects to `/onboarding` when the `VISIT_COPILOT` step is not in `completedSteps`.
|
||||
Mark it complete via the backend API so every browser test lands on the real feature UI:
|
||||
|
||||
```bash
|
||||
ONBOARDING_RESULT=$(curl -s --max-time 30 -X POST \
|
||||
"http://localhost:8006/api/onboarding/step?step=VISIT_COPILOT" \
|
||||
-H "Authorization: Bearer $TOKEN")
|
||||
echo "Onboarding bypass: $ONBOARDING_RESULT"
|
||||
|
||||
# Verify it took effect
|
||||
ONBOARDING_STATUS=$(curl -s --max-time 30 \
|
||||
"http://localhost:8006/api/onboarding/completed" \
|
||||
-H "Authorization: Bearer $TOKEN" | jq -r '.is_completed')
|
||||
echo "Onboarding completed: $ONBOARDING_STATUS"
|
||||
if [ "$ONBOARDING_STATUS" != "true" ]; then
|
||||
echo "ERROR: onboarding bypass failed — browser tests will hit /onboarding instead of the target feature. Investigate before proceeding."
|
||||
exit 1
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 4: Run tests
|
||||
|
||||
### Service ports reference
|
||||
@@ -547,6 +569,8 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations
|
||||
|
||||
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
|
||||
|
||||
**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
|
||||
|
||||
```bash
|
||||
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
|
||||
REPO="Significant-Gravitas/AutoGPT"
|
||||
@@ -584,15 +608,27 @@ TREE_JSON+=']'
|
||||
|
||||
# Step 2: Create tree, commit, and branch ref
|
||||
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
--jq '.sha')
|
||||
|
||||
# Resolve parent commit so screenshots are chained, not orphan root commits
|
||||
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || echo "")
|
||||
if [ -n "$PARENT_SHA" ]; then
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
-f "parents[]=$PARENT_SHA" \
|
||||
--jq '.sha')
|
||||
else
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
--jq '.sha')
|
||||
fi
|
||||
|
||||
gh api "repos/${REPO}/git/refs" \
|
||||
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-f sha="$COMMIT_SHA" 2>/dev/null \
|
||||
|| gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-X PATCH -f sha="$COMMIT_SHA" -f force=true
|
||||
-X PATCH -f sha="$COMMIT_SHA" -F force=true
|
||||
```
|
||||
|
||||
Then post the comment with **inline images AND explanations for each screenshot**:
|
||||
@@ -658,6 +694,15 @@ INNEREOF
|
||||
|
||||
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
|
||||
rm -f "$COMMENT_FILE"
|
||||
|
||||
# Verify the posted comment contains inline images — exit 1 if none found
|
||||
# Use separate --paginate + jq pipe: --jq applies per-page, not to the full list
|
||||
LAST_COMMENT=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" --paginate 2>/dev/null | jq -r '.[-1].body // ""')
|
||||
if ! echo "$LAST_COMMENT" | grep -q '!\['; then
|
||||
echo "ERROR: Posted comment contains no inline images (![). Bare directory links are not acceptable." >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "✓ Inline images verified in posted comment"
|
||||
```
|
||||
|
||||
**The PR comment MUST include:**
|
||||
@@ -667,6 +712,103 @@ rm -f "$COMMENT_FILE"
|
||||
|
||||
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
|
||||
|
||||
## Step 8: Evaluate and post a formal PR review
|
||||
|
||||
After the test comment is posted, evaluate whether the run was thorough enough to make a merge decision, then post a formal GitHub review (approve or request changes). **This step is mandatory — every test run MUST end with a formal review decision.**
|
||||
|
||||
### Evaluation criteria
|
||||
|
||||
Re-read the PR description:
|
||||
```bash
|
||||
gh pr view "$PR_NUMBER" --json body --jq '.body' --repo "$REPO"
|
||||
```
|
||||
|
||||
Score the run against each criterion:
|
||||
|
||||
| Criterion | Pass condition |
|
||||
|-----------|---------------|
|
||||
| **Coverage** | Every feature/change described in the PR has at least one test scenario |
|
||||
| **All scenarios pass** | No FAIL rows in the results table |
|
||||
| **Negative tests** | At least one failure-path test per feature (invalid input, unauthorized, edge case) |
|
||||
| **Before/after evidence** | Every state-changing API call has before/after values logged |
|
||||
| **Screenshots are meaningful** | Screenshots show the actual state change, not just a loading spinner or blank page |
|
||||
| **No regressions** | Existing core flows (login, agent create/run) still work |
|
||||
|
||||
### Decision logic
|
||||
|
||||
```
|
||||
ALL criteria pass → APPROVE
|
||||
Any scenario FAIL or missing PR feature → REQUEST_CHANGES (list gaps)
|
||||
Evidence weak (no before/after, vague shots) → REQUEST_CHANGES (list what's missing)
|
||||
```
|
||||
|
||||
### Post the review
|
||||
|
||||
```bash
|
||||
REVIEW_FILE=$(mktemp)
|
||||
|
||||
# Count results
|
||||
PASS_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "PASS" || true)
|
||||
FAIL_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "FAIL" || true)
|
||||
TOTAL=$(( PASS_COUNT + FAIL_COUNT ))
|
||||
|
||||
# List any coverage gaps found during evaluation (populate this array as you assess)
|
||||
# e.g. COVERAGE_GAPS=("PR claims to add X but no test covers it")
|
||||
COVERAGE_GAPS=()
|
||||
```
|
||||
|
||||
**If APPROVING** — all criteria met, zero failures, full coverage:
|
||||
|
||||
```bash
|
||||
cat > "$REVIEW_FILE" <<REVIEWEOF
|
||||
## E2E Test Evaluation — APPROVED
|
||||
|
||||
**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed.
|
||||
|
||||
**Coverage:** All features described in the PR were exercised.
|
||||
|
||||
**Evidence:** Before/after API values logged for all state-changing operations; screenshots show meaningful state transitions.
|
||||
|
||||
**Negative tests:** Failure paths tested for each feature.
|
||||
|
||||
No regressions observed on core flows.
|
||||
REVIEWEOF
|
||||
|
||||
gh pr review "$PR_NUMBER" --repo "$REPO" --approve --body "$(cat "$REVIEW_FILE")"
|
||||
echo "✅ PR approved"
|
||||
```
|
||||
|
||||
**If REQUESTING CHANGES** — any failure, coverage gap, or missing evidence:
|
||||
|
||||
```bash
|
||||
FAIL_LIST=$(echo "$TEST_RESULTS_TABLE" | grep "FAIL" | awk -F'|' '{print "- Scenario" $2 "failed"}' || true)
|
||||
|
||||
cat > "$REVIEW_FILE" <<REVIEWEOF
|
||||
## E2E Test Evaluation — Changes Requested
|
||||
|
||||
**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed, ${FAIL_COUNT} failed.
|
||||
|
||||
### Required before merge
|
||||
|
||||
${FAIL_LIST}
|
||||
$(for gap in "${COVERAGE_GAPS[@]}"; do echo "- $gap"; done)
|
||||
|
||||
Please fix the above and re-run the E2E tests.
|
||||
REVIEWEOF
|
||||
|
||||
gh pr review "$PR_NUMBER" --repo "$REPO" --request-changes --body "$(cat "$REVIEW_FILE")"
|
||||
echo "❌ Changes requested"
|
||||
```
|
||||
|
||||
```bash
|
||||
rm -f "$REVIEW_FILE"
|
||||
```
|
||||
|
||||
**Rules:**
|
||||
- In `--fix` mode, fix all failures before posting the review — the review reflects the final state after fixes
|
||||
- Never approve if any scenario failed, even if it seems like a flake — rerun that scenario first
|
||||
- Never request changes for issues already fixed in this run
|
||||
|
||||
## Fix mode (--fix flag)
|
||||
|
||||
When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.
|
||||
|
||||
225
.claude/skills/write-frontend-tests/SKILL.md
Normal file
225
.claude/skills/write-frontend-tests/SKILL.md
Normal file
@@ -0,0 +1,225 @@
|
||||
---
|
||||
name: write-frontend-tests
|
||||
description: "Analyze the current branch diff against dev, plan integration tests for changed frontend pages/components, and write them. TRIGGER when user asks to write frontend tests, add test coverage, or 'write tests for my changes'."
|
||||
user-invocable: true
|
||||
args: "[base branch] — defaults to dev. Optionally pass a specific base branch to diff against."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Write Frontend Tests
|
||||
|
||||
Analyze the current branch's frontend changes, plan integration tests, and write them.
|
||||
|
||||
## References
|
||||
|
||||
Before writing any tests, read the testing rules and conventions:
|
||||
|
||||
- `autogpt_platform/frontend/TESTING.md` — testing strategy, file locations, examples
|
||||
- `autogpt_platform/frontend/src/tests/AGENTS.md` — detailed testing rules, MSW patterns, decision flowchart
|
||||
- `autogpt_platform/frontend/src/tests/integrations/test-utils.tsx` — custom render with providers
|
||||
- `autogpt_platform/frontend/src/tests/integrations/vitest.setup.tsx` — MSW server setup
|
||||
|
||||
## Step 1: Identify changed frontend files
|
||||
|
||||
```bash
|
||||
BASE_BRANCH="${ARGUMENTS:-dev}"
|
||||
cd autogpt_platform/frontend
|
||||
|
||||
# Get changed frontend files (excluding generated, config, and test files)
|
||||
git diff "$BASE_BRANCH"...HEAD --name-only -- src/ \
|
||||
| grep -v '__generated__' \
|
||||
| grep -v '__tests__' \
|
||||
| grep -v '\.test\.' \
|
||||
| grep -v '\.stories\.' \
|
||||
| grep -v '\.spec\.'
|
||||
```
|
||||
|
||||
Also read the diff to understand what changed:
|
||||
|
||||
```bash
|
||||
git diff "$BASE_BRANCH"...HEAD --stat -- src/
|
||||
git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
|
||||
```
|
||||
|
||||
## Step 2: Categorize changes and find test targets
|
||||
|
||||
For each changed file, determine:
|
||||
|
||||
1. **Is it a page?** (`page.tsx`) — these are the primary test targets
|
||||
2. **Is it a hook?** (`use*.ts`) — test via the page/component that uses it; avoid direct `renderHook()` tests unless it is a shared reusable hook with standalone business logic
|
||||
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
|
||||
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
|
||||
|
||||
**Priority order:**
|
||||
|
||||
1. Pages with new/changed data fetching or user interactions
|
||||
2. Components with complex internal logic (modals, forms, wizards)
|
||||
3. Shared hooks with standalone business logic when UI-level coverage is impractical
|
||||
4. Pure helper functions
|
||||
|
||||
Skip: styling-only changes, type-only changes, config changes.
|
||||
|
||||
## Step 3: Check for existing tests
|
||||
|
||||
For each test target, check if tests already exist:
|
||||
|
||||
```bash
|
||||
# For a page at src/app/(platform)/library/page.tsx
|
||||
ls src/app/\(platform\)/library/__tests__/ 2>/dev/null
|
||||
|
||||
# For a component at src/app/(platform)/library/components/AgentCard/AgentCard.tsx
|
||||
ls src/app/\(platform\)/library/components/AgentCard/__tests__/ 2>/dev/null
|
||||
```
|
||||
|
||||
Note which targets have no tests (need new files) vs which have tests that need updating.
|
||||
|
||||
## Step 4: Identify API endpoints used
|
||||
|
||||
For each test target, find which API hooks are used:
|
||||
|
||||
```bash
|
||||
# Find generated API hook imports in the changed files
|
||||
grep -rn 'from.*__generated__/endpoints' src/app/\(platform\)/library/
|
||||
grep -rn 'use[A-Z].*V[12]' src/app/\(platform\)/library/
|
||||
```
|
||||
|
||||
For each API hook found, locate the corresponding MSW handler:
|
||||
|
||||
```bash
|
||||
# If the page uses useGetV2ListLibraryAgents, find its MSW handlers
|
||||
grep -rn 'getGetV2ListLibraryAgents.*Handler' src/app/api/__generated__/endpoints/library/library.msw.ts
|
||||
```
|
||||
|
||||
List every MSW handler you will need (200 for happy path, 4xx for error paths).
|
||||
|
||||
## Step 5: Write the test plan
|
||||
|
||||
Before writing code, output a plan as a numbered list:
|
||||
|
||||
```
|
||||
Test plan for [branch name]:
|
||||
|
||||
1. src/app/(platform)/library/__tests__/main.test.tsx (NEW)
|
||||
- Renders page with agent list (MSW 200)
|
||||
- Shows loading state
|
||||
- Shows error state (MSW 422)
|
||||
- Handles empty agent list
|
||||
|
||||
2. src/app/(platform)/library/__tests__/search.test.tsx (NEW)
|
||||
- Filters agents by search query
|
||||
- Shows no results message
|
||||
- Clears search
|
||||
|
||||
3. src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx (UPDATE)
|
||||
- Add test for new "duplicate" action
|
||||
```
|
||||
|
||||
Present this plan to the user. Wait for confirmation before proceeding. If the user has feedback, adjust the plan.
|
||||
|
||||
## Step 6: Write the tests
|
||||
|
||||
For each test file in the plan, follow these conventions:
|
||||
|
||||
### File structure
|
||||
|
||||
```tsx
|
||||
import { render, screen, waitFor } from "@/tests/integrations/test-utils";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
// Import MSW handlers for endpoints the page uses
|
||||
import {
|
||||
getGetV2ListLibraryAgentsMockHandler200,
|
||||
getGetV2ListLibraryAgentsMockHandler422,
|
||||
} from "@/app/api/__generated__/endpoints/library/library.msw";
|
||||
// Import the component under test
|
||||
import LibraryPage from "../page";
|
||||
|
||||
describe("LibraryPage", () => {
|
||||
test("renders agent list from API", async () => {
|
||||
server.use(getGetV2ListLibraryAgentsMockHandler200());
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
expect(await screen.findByText(/my agents/i)).toBeDefined();
|
||||
});
|
||||
|
||||
test("shows error state on API failure", async () => {
|
||||
server.use(getGetV2ListLibraryAgentsMockHandler422());
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
expect(await screen.findByText(/error/i)).toBeDefined();
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Rules
|
||||
|
||||
- Use `render()` from `@/tests/integrations/test-utils` (NOT from `@testing-library/react` directly)
|
||||
- Use `server.use()` to set up MSW handlers BEFORE rendering
|
||||
- Use `findBy*` (async) for elements that appear after data fetching — NOT `getBy*`
|
||||
- Use `getBy*` only for elements that are immediately present in the DOM
|
||||
- Use `screen` queries — do NOT destructure from `render()`
|
||||
- Use `waitFor` when asserting side effects or state changes after interactions
|
||||
- Import `fireEvent` or `userEvent` from the test-utils for interactions
|
||||
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
|
||||
- Prefer Orval-generated MSW handlers and response builders over hand-built API response objects
|
||||
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
|
||||
- Keep tests focused: one behavior per test
|
||||
- Use descriptive test names that read like sentences
|
||||
|
||||
### Test location
|
||||
|
||||
```
|
||||
# For pages: __tests__/ next to page.tsx
|
||||
src/app/(platform)/library/__tests__/main.test.tsx
|
||||
|
||||
# For complex standalone components: __tests__/ inside component folder
|
||||
src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx
|
||||
|
||||
# For pure helpers: co-located .test.ts
|
||||
src/app/(platform)/library/helpers.test.ts
|
||||
```
|
||||
|
||||
### Custom MSW overrides
|
||||
|
||||
When the auto-generated faker data is not enough, override with specific data:
|
||||
|
||||
```tsx
|
||||
import { http, HttpResponse } from "msw";
|
||||
|
||||
server.use(
|
||||
http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
|
||||
return HttpResponse.json({
|
||||
agents: [{ id: "1", name: "Test Agent", description: "A test agent" }],
|
||||
pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
|
||||
});
|
||||
}),
|
||||
);
|
||||
```
|
||||
|
||||
Use the proxy URL pattern: `http://localhost:3000/api/proxy/api/v{version}/{path}` — this matches the MSW base URL configured in `orval.config.ts`.
|
||||
|
||||
## Step 7: Run and verify
|
||||
|
||||
After writing all tests:
|
||||
|
||||
```bash
|
||||
cd autogpt_platform/frontend
|
||||
pnpm test:unit --reporter=verbose
|
||||
```
|
||||
|
||||
If tests fail:
|
||||
|
||||
1. Read the error output carefully
|
||||
2. Fix the test (not the source code, unless there is a genuine bug)
|
||||
3. Re-run until all pass
|
||||
|
||||
Then run the full checks:
|
||||
|
||||
```bash
|
||||
pnpm format
|
||||
pnpm lint
|
||||
pnpm types
|
||||
```
|
||||
78
.github/workflows/classic-autogpt-ci.yml
vendored
78
.github/workflows/classic-autogpt-ci.yml
vendored
@@ -6,11 +6,19 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -19,47 +27,22 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/original_autogpt
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
- name: Start MinIO service
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -71,41 +54,23 @@ jobs:
|
||||
git config --global user.name "Auto-GPT-Bot"
|
||||
git config --global user.email "github-bot@agpt.co"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: "3.12"
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
@@ -116,12 +81,13 @@ jobs:
|
||||
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--numprocesses=logical --durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests/unit tests/integration
|
||||
original_autogpt/tests/unit original_autogpt/tests/integration
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -135,11 +101,11 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: autogpt-agent,${{ runner.os }}
|
||||
flags: autogpt-agent
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/original_autogpt/logs/
|
||||
path: classic/logs/
|
||||
|
||||
@@ -148,7 +148,7 @@ jobs:
|
||||
--entrypoint poetry ${{ env.IMAGE_NAME }} run \
|
||||
pytest -v --cov=autogpt --cov-branch --cov-report term-missing \
|
||||
--numprocesses=4 --durations=10 \
|
||||
tests/unit tests/integration 2>&1 | tee test_output.txt
|
||||
original_autogpt/tests/unit original_autogpt/tests/integration 2>&1 | tee test_output.txt
|
||||
|
||||
test_failure=${PIPESTATUS[0]}
|
||||
|
||||
|
||||
44
.github/workflows/classic-autogpts-ci.yml
vendored
44
.github/workflows/classic-autogpts-ci.yml
vendored
@@ -10,10 +10,9 @@ on:
|
||||
- '.github/workflows/classic-autogpts-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '!**/*.md'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
@@ -21,10 +20,9 @@ on:
|
||||
- '.github/workflows/classic-autogpts-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '!**/*.md'
|
||||
|
||||
defaults:
|
||||
@@ -35,13 +33,9 @@ defaults:
|
||||
jobs:
|
||||
serve-agent-protocol:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [ original_autogpt ]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
min-python-version: '3.12'
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -55,22 +49,22 @@ jobs:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
working-directory: ./classic/${{ matrix.agent-name }}/
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
|
||||
- name: Run regression tests
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run smoke tests with direct-benchmark
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
|
||||
poetry run agbenchmark --test=WriteFile
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
AGENT_NAME: ${{ matrix.agent-name }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
|
||||
HELICONE_CACHE_ENABLED: false
|
||||
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
|
||||
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
CI: true
|
||||
|
||||
256
.github/workflows/classic-benchmark-ci.yml
vendored
256
.github/workflows/classic-benchmark-ci.yml
vendored
@@ -1,18 +1,24 @@
|
||||
name: Classic - AGBenchmark CI
|
||||
name: Classic - Direct Benchmark CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master, dev, ci-test* ]
|
||||
paths:
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -23,95 +29,16 @@ defaults:
|
||||
shell: bash
|
||||
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
min-python-version: '3.12'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
benchmark-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/benchmark
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
poetry run pytest -vv \
|
||||
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests
|
||||
env:
|
||||
CI: true
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Upload test results to Codecov
|
||||
if: ${{ !cancelled() }} # Run even if tests fail
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: agbenchmark,${{ runner.os }}
|
||||
|
||||
self-test-with-agent:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [forge]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
working-directory: classic
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -124,53 +51,120 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run basic benchmark tests
|
||||
run: |
|
||||
echo "Testing ReadFile challenge with one_shot strategy..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--json
|
||||
|
||||
echo "Testing WriteFile challenge..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests WriteFile \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
- name: Test category filtering
|
||||
run: |
|
||||
echo "Testing coding category..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--categories coding \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
- name: Test multiple strategies
|
||||
run: |
|
||||
echo "Testing multiple strategies..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot,plan_execute \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--parallel 2 \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
# Run regression tests on maintain challenges
|
||||
regression-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run regression tests
|
||||
working-directory: classic
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
|
||||
set +e # Ignore non-zero exit codes and continue execution
|
||||
echo "Running the following command: poetry run agbenchmark --maintain --mock"
|
||||
poetry run agbenchmark --maintain --mock
|
||||
EXIT_CODE=$?
|
||||
set -e # Stop ignoring non-zero exit codes
|
||||
# Check if the exit code was 5, and if so, exit with 0 instead
|
||||
if [ $EXIT_CODE -eq 5 ]; then
|
||||
echo "regression_tests.json is empty."
|
||||
fi
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock"
|
||||
poetry run agbenchmark --mock
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=data"
|
||||
poetry run agbenchmark --mock --category=data
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
|
||||
poetry run agbenchmark --mock --category=coding
|
||||
|
||||
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
|
||||
# poetry run agbenchmark --test=WriteFile
|
||||
cd ../benchmark
|
||||
poetry install
|
||||
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
|
||||
export BUILD_SKILL_TREE=true
|
||||
|
||||
# poetry run agbenchmark --mock
|
||||
|
||||
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
|
||||
# if [ ! -z "$CHANGED" ]; then
|
||||
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
|
||||
# echo "$CHANGED"
|
||||
# exit 1
|
||||
# else
|
||||
# echo "No unstaged changes."
|
||||
# fi
|
||||
echo "Running regression tests (previously beaten challenges)..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--maintain \
|
||||
--parallel 4 \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
189
.github/workflows/classic-forge-ci.yml
vendored
189
.github/workflows/classic-forge-ci.yml
vendored
@@ -6,13 +6,15 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -21,131 +23,60 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/forge
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
- name: Start MinIO service
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Checkout cassettes
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
env:
|
||||
PR_BASE: ${{ github.event.pull_request.base.ref }}
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
cassette_base_branch="${PR_BASE}"
|
||||
cd tests/vcr_cassettes
|
||||
|
||||
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
|
||||
cassette_base_branch="master"
|
||||
fi
|
||||
|
||||
if git ls-remote --exit-code --heads origin $cassette_branch ; then
|
||||
git fetch origin $cassette_branch
|
||||
git fetch origin $cassette_base_branch
|
||||
|
||||
git checkout $cassette_branch
|
||||
|
||||
# Pick non-conflicting cassette updates from the base branch
|
||||
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
|
||||
echo "Using cassettes from mirror branch '$cassette_branch'," \
|
||||
"synced to upstream branch '$cassette_base_branch'."
|
||||
else
|
||||
git checkout -b $cassette_branch
|
||||
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
|
||||
"Using cassettes from '$cassette_base_branch'."
|
||||
fi
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Install Playwright browsers
|
||||
run: poetry run playwright install chromium
|
||||
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
poetry run pytest -vv \
|
||||
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
forge
|
||||
forge/forge forge/tests
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
# API keys - tests that need these will skip if not available
|
||||
# Secrets are not available to fork PRs (GitHub security feature)
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -159,85 +90,11 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: forge,${{ runner.os }}
|
||||
|
||||
- id: setup_git_auth
|
||||
name: Set up git token authentication
|
||||
# Cassettes may be pushed even when tests fail
|
||||
if: success() || failure()
|
||||
run: |
|
||||
config_key="http.${{ github.server_url }}/.extraheader"
|
||||
if [ "${{ runner.os }}" = 'macOS' ]; then
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
|
||||
else
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
|
||||
fi
|
||||
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
echo "config_key=$config_key" >> $GITHUB_OUTPUT
|
||||
|
||||
- id: push_cassettes
|
||||
name: Push updated cassettes
|
||||
# For pull requests, push updated cassettes even when tests fail
|
||||
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
|
||||
env:
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
|
||||
is_pull_request=true
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
else
|
||||
cassette_branch="${{ github.ref_name }}"
|
||||
fi
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
# Commit & push changes to cassettes if any
|
||||
if ! git diff --quiet; then
|
||||
git add .
|
||||
git commit -m "Auto-update cassettes"
|
||||
git push origin HEAD:$cassette_branch
|
||||
if [ ! $is_pull_request ]; then
|
||||
cd ../..
|
||||
git add tests/vcr_cassettes
|
||||
git commit -m "Update cassette submodule"
|
||||
git push origin HEAD:$cassette_branch
|
||||
fi
|
||||
echo "updated=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "updated=false" >> $GITHUB_OUTPUT
|
||||
echo "No cassette changes to commit"
|
||||
fi
|
||||
|
||||
- name: Post Set up git token auth
|
||||
if: steps.setup_git_auth.outcome == 'success'
|
||||
run: |
|
||||
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
|
||||
- name: Apply "behaviour change" label and comment on PR
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
run: |
|
||||
PR_NUMBER="${{ github.event.pull_request.number }}"
|
||||
TOKEN="${{ secrets.PAT_REVIEW }}"
|
||||
REPO="${{ github.repository }}"
|
||||
|
||||
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
|
||||
echo "Adding label and comment..."
|
||||
echo $TOKEN | gh auth login --with-token
|
||||
gh issue edit $PR_NUMBER --add-label "behaviour change"
|
||||
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
|
||||
fi
|
||||
flags: forge
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/forge/logs/
|
||||
path: classic/logs/
|
||||
|
||||
60
.github/workflows/classic-frontend-ci.yml
vendored
60
.github/workflows/classic-frontend-ci.yml
vendored
@@ -1,60 +0,0 @@
|
||||
name: Classic - Frontend CI/CD
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- dev
|
||||
- 'ci-test*' # This will match any branch that starts with "ci-test"
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Flutter
|
||||
uses: subosito/flutter-action@v2
|
||||
with:
|
||||
flutter-version: '3.13.2'
|
||||
|
||||
- name: Build Flutter to Web
|
||||
run: |
|
||||
cd classic/frontend
|
||||
flutter build web --base-href /app/
|
||||
|
||||
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
|
||||
# if: github.event_name == 'push'
|
||||
# run: |
|
||||
# git config --local user.email "action@github.com"
|
||||
# git config --local user.name "GitHub Action"
|
||||
# git add classic/frontend/build/web
|
||||
# git checkout -B ${{ env.BUILD_BRANCH }}
|
||||
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
|
||||
# git push -f origin ${{ env.BUILD_BRANCH }}
|
||||
|
||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||
if: github.event_name == 'push'
|
||||
uses: peter-evans/create-pull-request@v8
|
||||
with:
|
||||
add-paths: classic/frontend/build/web
|
||||
base: ${{ github.ref_name }}
|
||||
branch: ${{ env.BUILD_BRANCH }}
|
||||
delete-branch: true
|
||||
title: "Update frontend build in `${{ github.ref_name }}`"
|
||||
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
|
||||
commit-message: "Update frontend build based on commit ${{ github.sha }}"
|
||||
67
.github/workflows/classic-python-checks.yml
vendored
67
.github/workflows/classic-python-checks.yml
vendored
@@ -7,7 +7,9 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
pull_request:
|
||||
@@ -16,7 +18,9 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
|
||||
@@ -27,44 +31,13 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
get-changed-parts:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- id: changes-in
|
||||
name: Determine affected subprojects
|
||||
uses: dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
original_autogpt:
|
||||
- classic/original_autogpt/autogpt/**
|
||||
- classic/original_autogpt/tests/**
|
||||
- classic/original_autogpt/poetry.lock
|
||||
forge:
|
||||
- classic/forge/forge/**
|
||||
- classic/forge/tests/**
|
||||
- classic/forge/poetry.lock
|
||||
benchmark:
|
||||
- classic/benchmark/agbenchmark/**
|
||||
- classic/benchmark/tests/**
|
||||
- classic/benchmark/poetry.lock
|
||||
outputs:
|
||||
changed-parts: ${{ steps.changes-in.outputs.changes }}
|
||||
|
||||
lint:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
min-python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -81,42 +54,31 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
run: poetry install
|
||||
|
||||
# Lint
|
||||
|
||||
- name: Lint (isort)
|
||||
run: poetry run isort --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Black)
|
||||
if: success() || failure()
|
||||
run: poetry run black --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Flake8)
|
||||
if: success() || failure()
|
||||
run: poetry run flake8 .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
types:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
min-python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -133,19 +95,16 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
run: poetry install
|
||||
|
||||
# Typecheck
|
||||
|
||||
- name: Typecheck
|
||||
if: success() || failure()
|
||||
run: poetry run pyright
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
20
.github/workflows/platform-backend-ci.yml
vendored
20
.github/workflows/platform-backend-ci.yml
vendored
@@ -269,12 +269,14 @@ jobs:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- name: Run pytest
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG \
|
||||
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
|
||||
else
|
||||
poetry run pytest -s -vv
|
||||
poetry run pytest -s -vv \
|
||||
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
|
||||
fi
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
@@ -287,11 +289,13 @@ jobs:
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-backend
|
||||
files: ./autogpt_platform/backend/coverage.xml
|
||||
|
||||
env:
|
||||
CI: true
|
||||
|
||||
8
.github/workflows/platform-frontend-ci.yml
vendored
8
.github/workflows/platform-frontend-ci.yml
vendored
@@ -148,3 +148,11 @@ jobs:
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: pnpm test:unit
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-frontend
|
||||
files: ./autogpt_platform/frontend/coverage/cobertura-coverage.xml
|
||||
|
||||
38
.github/workflows/platform-fullstack-ci.yml
vendored
38
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -160,6 +160,7 @@ jobs:
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
echo "SCHEDULER_STARTUP_EMBEDDING_BACKFILL=false" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
@@ -179,21 +180,30 @@ jobs:
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
export NEXT_PUBLIC_SOURCEMAPS NEXT_PUBLIC_PW_TEST
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Ensure NEXT_PUBLIC_SOURCEMAPS is in resolved compose
|
||||
# (docker compose config on some versions drops this arg)
|
||||
if ! grep -q "NEXT_PUBLIC_SOURCEMAPS" docker-compose.resolved.yml; then
|
||||
echo "Injecting NEXT_PUBLIC_SOURCEMAPS into resolved compose (docker compose config dropped it)"
|
||||
sed -i '/NEXT_PUBLIC_PW_TEST/a\ NEXT_PUBLIC_SOURCEMAPS: "true"' docker-compose.resolved.yml
|
||||
fi
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}-sourcemaps" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
NEXT_PUBLIC_SOURCEMAPS: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
@@ -279,16 +289,38 @@ jobs:
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Cache Playwright browsers
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
key: playwright-${{ runner.os }}-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
playwright-${{ runner.os }}-
|
||||
|
||||
- name: Copy source maps from Docker for E2E coverage
|
||||
run: |
|
||||
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
|
||||
docker cp "$FRONTEND_CONTAINER":/app/.next/static .next-static-coverage
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
- name: Run Playwright E2E suite
|
||||
run: pnpm test:e2e:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload E2E coverage to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-frontend-e2e
|
||||
files: ./autogpt_platform/frontend/coverage/e2e/cobertura-coverage.xml
|
||||
disable_search: true
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
|
||||
12
.gitignore
vendored
12
.gitignore
vendored
@@ -3,6 +3,7 @@
|
||||
classic/original_autogpt/keys.py
|
||||
classic/original_autogpt/*.json
|
||||
auto_gpt_workspace/*
|
||||
.autogpt/
|
||||
*.mpeg
|
||||
.env
|
||||
# Root .env files
|
||||
@@ -16,6 +17,7 @@ log-ingestion.txt
|
||||
/logs
|
||||
*.log
|
||||
*.mp3
|
||||
!autogpt_platform/frontend/public/notification.mp3
|
||||
mem.sqlite3
|
||||
venvAutoGPT
|
||||
|
||||
@@ -159,6 +161,10 @@ CURRENT_BULLETIN.md
|
||||
|
||||
# AgBenchmark
|
||||
classic/benchmark/agbenchmark/reports/
|
||||
classic/reports/
|
||||
classic/direct_benchmark/reports/
|
||||
classic/.benchmark_workspaces/
|
||||
classic/direct_benchmark/.benchmark_workspaces/
|
||||
|
||||
# Nodejs
|
||||
package-lock.json
|
||||
@@ -177,9 +183,15 @@ autogpt_platform/backend/settings.py
|
||||
|
||||
*.ign.*
|
||||
.test-contents
|
||||
**/.claude/settings.local.json
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
/autogpt_platform/backend/poetry.toml
|
||||
|
||||
# Test database
|
||||
test.db
|
||||
.next
|
||||
# Implementation plans (generated by AI agents)
|
||||
plans/
|
||||
.claude/worktrees/
|
||||
|
||||
36
.gitleaks.toml
Normal file
36
.gitleaks.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
title = "AutoGPT Gitleaks Config"
|
||||
|
||||
[extend]
|
||||
useDefault = true
|
||||
|
||||
[allowlist]
|
||||
description = "Global allowlist"
|
||||
paths = [
|
||||
# Template/example env files (no real secrets)
|
||||
'''\.env\.(default|example|template)$''',
|
||||
# Lock files
|
||||
'''pnpm-lock\.yaml$''',
|
||||
'''poetry\.lock$''',
|
||||
# Secrets baseline
|
||||
'''\.secrets\.baseline$''',
|
||||
# Build artifacts and caches (should not be committed)
|
||||
'''__pycache__/''',
|
||||
'''classic/frontend/build/''',
|
||||
# Docker dev setup (local dev JWTs/keys only)
|
||||
'''autogpt_platform/db/docker/''',
|
||||
# Load test configs (dev JWTs)
|
||||
'''load-tests/configs/''',
|
||||
# Test files with fake/fixture keys (_test.py, test_*.py, conftest.py)
|
||||
'''(_test|test_.*|conftest)\.py$''',
|
||||
# Documentation (only contains placeholder keys in curl/API examples)
|
||||
'''docs/.*\.md$''',
|
||||
# Firebase config (public API keys by design)
|
||||
'''google-services\.json$''',
|
||||
'''classic/frontend/(lib|web)/''',
|
||||
]
|
||||
# CI test-only encryption key (marked DO NOT USE IN PRODUCTION)
|
||||
regexes = [
|
||||
'''dvziYgz0KSK8FENhju0ZYi8''',
|
||||
# LLM model name enum values falsely flagged as API keys
|
||||
'''Llama-\d.*Instruct''',
|
||||
]
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +0,0 @@
|
||||
[submodule "classic/forge/tests/vcr_cassettes"]
|
||||
path = classic/forge/tests/vcr_cassettes
|
||||
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
|
||||
@@ -23,9 +23,15 @@ repos:
|
||||
- id: detect-secrets
|
||||
name: Detect secrets
|
||||
description: Detects high entropy strings that are likely to be passwords.
|
||||
args: ["--baseline", ".secrets.baseline"]
|
||||
files: ^autogpt_platform/
|
||||
exclude: pnpm-lock\.yaml$
|
||||
stages: [pre-push]
|
||||
exclude: (pnpm-lock\.yaml|\.env\.(default|example|template))$
|
||||
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.24.3
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
name: Detect secrets (gitleaks)
|
||||
|
||||
- repo: local
|
||||
# For proper type checking, all dependencies need to be up-to-date.
|
||||
@@ -84,51 +90,16 @@ repos:
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - AutoGPT
|
||||
alias: poetry-install-classic-autogpt
|
||||
name: Check & Install dependencies - Classic
|
||||
alias: poetry-install-classic
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/original_autogpt install
|
||||
'
|
||||
# include forge source (since it's a path dependency)
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Forge
|
||||
alias: poetry-install-classic-forge
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/forge install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Benchmark
|
||||
alias: poetry-install-classic-benchmark
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/benchmark install
|
||||
fi | grep -qE "^classic/poetry\.lock$" || exit 0;
|
||||
poetry -C classic install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
@@ -223,26 +194,10 @@ repos:
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - AutoGPT
|
||||
alias: isort-classic-autogpt
|
||||
entry: poetry -P classic/original_autogpt run isort -p autogpt
|
||||
files: ^classic/original_autogpt/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Forge
|
||||
alias: isort-classic-forge
|
||||
entry: poetry -P classic/forge run isort -p forge
|
||||
files: ^classic/forge/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Benchmark
|
||||
alias: isort-classic-benchmark
|
||||
entry: poetry -P classic/benchmark run isort -p agbenchmark
|
||||
files: ^classic/benchmark/
|
||||
name: Lint (isort) - Classic
|
||||
alias: isort-classic
|
||||
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
@@ -256,26 +211,13 @@ repos:
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.0.0
|
||||
# To have flake8 load the config of the individual subprojects, we have to call
|
||||
# them separately.
|
||||
# Use consolidated flake8 config at classic/.flake8
|
||||
hooks:
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - AutoGPT
|
||||
alias: flake8-classic-autogpt
|
||||
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
|
||||
args: [--config=classic/original_autogpt/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Forge
|
||||
alias: flake8-classic-forge
|
||||
files: ^classic/forge/(forge|tests)/
|
||||
args: [--config=classic/forge/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Benchmark
|
||||
alias: flake8-classic-benchmark
|
||||
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
|
||||
args: [--config=classic/benchmark/.flake8]
|
||||
name: Lint (Flake8) - Classic
|
||||
alias: flake8-classic
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
args: [--config=classic/.flake8]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
@@ -311,29 +253,10 @@ repos:
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - AutoGPT
|
||||
alias: pyright-classic-autogpt
|
||||
entry: poetry -C classic/original_autogpt run pyright
|
||||
# include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Forge
|
||||
alias: pyright-classic-forge
|
||||
entry: poetry -C classic/forge run pyright
|
||||
files: ^classic/forge/(forge/|poetry\.lock$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Benchmark
|
||||
alias: pyright-classic-benchmark
|
||||
entry: poetry -C classic/benchmark run pyright
|
||||
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
name: Typecheck - Classic
|
||||
alias: pyright-classic
|
||||
entry: poetry -C classic run pyright
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
@@ -360,26 +283,9 @@ repos:
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - AutoGPT (excl. slow tests)
|
||||
# alias: pytest-classic-autogpt
|
||||
# entry: bash -c 'cd classic/original_autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
|
||||
# # include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
# files: ^(classic/original_autogpt/((autogpt|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - Forge (excl. slow tests)
|
||||
# alias: pytest-classic-forge
|
||||
# entry: bash -c 'cd classic/forge && poetry run pytest --cov=forge -m "not slow"'
|
||||
# files: ^classic/forge/(forge/|tests/|poetry\.lock$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - Benchmark
|
||||
# alias: pytest-classic-benchmark
|
||||
# entry: bash -c 'cd classic/benchmark && poetry run pytest --cov=benchmark'
|
||||
# files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
# name: Run tests - Classic (excl. slow tests)
|
||||
# alias: pytest-classic
|
||||
# entry: bash -c 'cd classic && poetry run pytest -m "not slow"'
|
||||
# files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
471
.secrets.baseline
Normal file
471
.secrets.baseline
Normal file
@@ -0,0 +1,471 @@
|
||||
{
|
||||
"version": "1.5.0",
|
||||
"plugins_used": [
|
||||
{
|
||||
"name": "ArtifactoryDetector"
|
||||
},
|
||||
{
|
||||
"name": "AWSKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "AzureStorageKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "Base64HighEntropyString",
|
||||
"limit": 4.5
|
||||
},
|
||||
{
|
||||
"name": "BasicAuthDetector"
|
||||
},
|
||||
{
|
||||
"name": "CloudantDetector"
|
||||
},
|
||||
{
|
||||
"name": "DiscordBotTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "GitHubTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "GitLabTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "HexHighEntropyString",
|
||||
"limit": 3.0
|
||||
},
|
||||
{
|
||||
"name": "IbmCloudIamDetector"
|
||||
},
|
||||
{
|
||||
"name": "IbmCosHmacDetector"
|
||||
},
|
||||
{
|
||||
"name": "IPPublicDetector"
|
||||
},
|
||||
{
|
||||
"name": "JwtTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "KeywordDetector",
|
||||
"keyword_exclude": ""
|
||||
},
|
||||
{
|
||||
"name": "MailchimpDetector"
|
||||
},
|
||||
{
|
||||
"name": "NpmDetector"
|
||||
},
|
||||
{
|
||||
"name": "OpenAIDetector"
|
||||
},
|
||||
{
|
||||
"name": "PrivateKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "PypiTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "SendGridDetector"
|
||||
},
|
||||
{
|
||||
"name": "SlackDetector"
|
||||
},
|
||||
{
|
||||
"name": "SoftlayerDetector"
|
||||
},
|
||||
{
|
||||
"name": "SquareOAuthDetector"
|
||||
},
|
||||
{
|
||||
"name": "StripeDetector"
|
||||
},
|
||||
{
|
||||
"name": "TelegramBotTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "TwilioKeyDetector"
|
||||
}
|
||||
],
|
||||
"filters_used": [
|
||||
{
|
||||
"path": "detect_secrets.filters.allowlist.is_line_allowlisted"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.common.is_baseline_file",
|
||||
"filename": ".secrets.baseline"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
|
||||
"min_level": 2
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_indirect_reference"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_likely_id_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_lock_file"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_potential_uuid"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_sequential_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_swagger_file"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_templated_secret"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.regex.should_exclude_file",
|
||||
"pattern": [
|
||||
"\\.env$",
|
||||
"pnpm-lock\\.yaml$",
|
||||
"\\.env\\.(default|example|template)$",
|
||||
"__pycache__",
|
||||
"_test\\.py$",
|
||||
"test_.*\\.py$",
|
||||
"conftest\\.py$",
|
||||
"poetry\\.lock$",
|
||||
"node_modules"
|
||||
]
|
||||
}
|
||||
],
|
||||
"results": {
|
||||
"autogpt_platform/backend/backend/api/external/v1/integrations.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/api/external/v1/integrations.py",
|
||||
"hashed_secret": "665b1e3851eefefa3fb878654292f16597d25155",
|
||||
"is_verified": false,
|
||||
"line_number": 289
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/airtable/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/airtable/_config.py",
|
||||
"hashed_secret": "57e168b03afb7c1ee3cdc4ee3db2fe1cc6e0df26",
|
||||
"is_verified": false,
|
||||
"line_number": 29
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/dataforseo/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/dataforseo/_config.py",
|
||||
"hashed_secret": "32ce93887331fa5d192f2876ea15ec000c7d58b8",
|
||||
"is_verified": false,
|
||||
"line_number": 12
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/checks.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/checks.py",
|
||||
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
|
||||
"is_verified": false,
|
||||
"line_number": 108
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/ci.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/ci.py",
|
||||
"hashed_secret": "90bd1b48e958257948487b90bee080ba5ed00caa",
|
||||
"is_verified": false,
|
||||
"line_number": 123
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "f96896dafced7387dcd22343b8ea29d3d2c65663",
|
||||
"is_verified": false,
|
||||
"line_number": 42
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "b80a94d5e70bedf4f5f89d2f5a5255cc9492d12e",
|
||||
"is_verified": false,
|
||||
"line_number": 193
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "75b17e517fe1b3136394f6bec80c4f892da75e42",
|
||||
"is_verified": false,
|
||||
"line_number": 344
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "b0bfb5e4e2394e7f8906e5ed1dffd88b2bc89dd5",
|
||||
"is_verified": false,
|
||||
"line_number": 534
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/statuses.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/statuses.py",
|
||||
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
|
||||
"is_verified": false,
|
||||
"line_number": 85
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/google/docs.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/google/docs.py",
|
||||
"hashed_secret": "c95da0c6696342c867ef0c8258d2f74d20fd94d4",
|
||||
"is_verified": false,
|
||||
"line_number": 203
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/google/sheets.py": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/google/sheets.py",
|
||||
"hashed_secret": "bd5a04fa3667e693edc13239b6d310c5c7a8564b",
|
||||
"is_verified": false,
|
||||
"line_number": 57
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/linear/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/linear/_config.py",
|
||||
"hashed_secret": "b37f020f42d6d613b6ce30103e4d408c4499b3bb",
|
||||
"is_verified": false,
|
||||
"line_number": 53
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/medium.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/medium.py",
|
||||
"hashed_secret": "ff998abc1ce6d8f01a675fa197368e44c8916e9c",
|
||||
"is_verified": false,
|
||||
"line_number": 131
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/replicate/replicate_block.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
|
||||
"hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
|
||||
"is_verified": false,
|
||||
"line_number": 55
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/slant3d/webhook.py",
|
||||
"hashed_secret": "36263c76947443b2f6e6b78153967ac4a7da99f9",
|
||||
"is_verified": false,
|
||||
"line_number": 100
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/talking_head.py": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/talking_head.py",
|
||||
"hashed_secret": "44ce2d66222529eea4a32932823466fc0601c799",
|
||||
"is_verified": false,
|
||||
"line_number": 113
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/wordpress/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/wordpress/_config.py",
|
||||
"hashed_secret": "e62679512436161b78e8a8d68c8829c2a1031ccb",
|
||||
"is_verified": false,
|
||||
"line_number": 17
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/util/cache.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/util/cache.py",
|
||||
"hashed_secret": "37f0c918c3fa47ca4a70e42037f9f123fdfbc75b",
|
||||
"is_verified": false,
|
||||
"line_number": 449
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts",
|
||||
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
|
||||
"is_verified": false,
|
||||
"line_number": 6
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json",
|
||||
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json",
|
||||
"hashed_secret": "5a6d1c612954979ea99ee33dbb2d231b00f6ac0a",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 6
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
|
||||
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
|
||||
"is_verified": false,
|
||||
"line_number": 8
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
|
||||
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
|
||||
"is_verified": false,
|
||||
"line_number": 7
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 192
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
|
||||
"hashed_secret": "86275db852204937bbdbdebe5fabe8536e030ab6",
|
||||
"is_verified": false,
|
||||
"line_number": 193
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
|
||||
"hashed_secret": "47acd2028cf81b5da88ddeedb2aea4eca4b71fbd",
|
||||
"is_verified": false,
|
||||
"line_number": 102
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
|
||||
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
|
||||
"is_verified": false,
|
||||
"line_number": 103
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "9c486c92f1a7420e1045c7ad963fbb7ba3621025",
|
||||
"is_verified": false,
|
||||
"line_number": 73
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "9277508c7a6effc8fb59163efbfada189e35425c",
|
||||
"is_verified": false,
|
||||
"line_number": 75
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "8dc7e2cb1d0935897d541bf5facab389b8a50340",
|
||||
"is_verified": false,
|
||||
"line_number": 77
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "79a26ad48775944299be6aaf9fb1d5302c1ed75b",
|
||||
"is_verified": false,
|
||||
"line_number": 79
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "a3b62b44500a1612e48d4cab8294df81561b3b1a",
|
||||
"is_verified": false,
|
||||
"line_number": 81
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "a58979bd0b21ef4f50417d001008e60dd7a85c64",
|
||||
"is_verified": false,
|
||||
"line_number": 83
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "6cb6e075f8e8c7c850f9d128d6608e5dbe209a79",
|
||||
"is_verified": false,
|
||||
"line_number": 85
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/lib/constants.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/lib/constants.ts",
|
||||
"hashed_secret": "27b924db06a28cc755fb07c54f0fddc30659fe4d",
|
||||
"is_verified": false,
|
||||
"line_number": 13
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/tests/credentials/index.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/tests/credentials/index.ts",
|
||||
"hashed_secret": "c18006fc138809314751cd1991f1e0b820fabd37",
|
||||
"is_verified": false,
|
||||
"line_number": 4
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_at": "2026-04-09T14:20:23Z"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
# AutoGPT Platform Contribution Guide
|
||||
|
||||
This guide provides context for Codex when updating the **autogpt_platform** folder.
|
||||
This guide provides context for coding agents when updating the **autogpt_platform** folder.
|
||||
|
||||
## Directory overview
|
||||
|
||||
@@ -30,7 +30,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
5. **Testing**: Integration tests (Vitest + RTL + MSW) are the default (~90%, page-level). Playwright for E2E critical flows. Storybook for design system components. See `autogpt_platform/frontend/TESTING.md`
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
@@ -47,7 +47,9 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
## Testing
|
||||
|
||||
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
||||
- Frontend: `pnpm test` or `pnpm test-ui` for Playwright tests. See `docs/content/platform/contributing/tests.md` for tips.
|
||||
- Frontend integration tests: `pnpm test:unit` (Vitest + RTL + MSW, primary testing approach).
|
||||
- Frontend E2E tests: `pnpm test` or `pnpm test-ui` for Playwright tests.
|
||||
- See `autogpt_platform/frontend/TESTING.md` for the full testing strategy.
|
||||
|
||||
Always run the relevant linters and tests before committing.
|
||||
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
||||
|
||||
@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
|
||||
|
||||
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
|
||||
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
|
||||
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
|
||||
|
||||
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
|
||||
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
|
||||
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.
|
||||
|
||||
|
||||
120
autogpt_platform/AGENTS.md
Normal file
120
autogpt_platform/AGENTS.md
Normal file
@@ -0,0 +1,120 @@
|
||||
# AutoGPT Platform
|
||||
|
||||
This file provides guidance to coding agents when working with code in this repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`frontend`): Next.js React application
|
||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||
|
||||
## Component Documentation
|
||||
|
||||
- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
|
||||
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns
|
||||
|
||||
## Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Branching Strategy
|
||||
|
||||
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||
- **`master`** is the production branch. Only used for production releases.
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
|
||||
```bash
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
## Summary
|
||||
- use `backticks` freely here
|
||||
PREOF
|
||||
gh pr create --title "..." --body-file "$PR_BODY" --base dev
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, follow a test-first approach:
|
||||
|
||||
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
|
||||
2. **Implement the fix/feature** — write the minimal code to make the test pass.
|
||||
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
|
||||
|
||||
This ensures every change is covered by a test and that the test actually validates the intended behavior.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
@@ -1,120 +1 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`frontend`): Next.js React application
|
||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||
|
||||
## Component Documentation
|
||||
|
||||
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
|
||||
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
|
||||
|
||||
## Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Branching Strategy
|
||||
|
||||
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||
- **`master`** is the production branch. Only used for production releases.
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
|
||||
```bash
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
## Summary
|
||||
- use `backticks` freely here
|
||||
PREOF
|
||||
gh pr create --title "..." --body-file "$PR_BODY" --base dev
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, follow a test-first approach:
|
||||
|
||||
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
|
||||
2. **Implement the fix/feature** — write the minimal code to make the test pass.
|
||||
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
|
||||
|
||||
This ensures every change is covered by a test and that the test actually validates the intended behavior.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
@AGENTS.md
|
||||
|
||||
100
autogpt_platform/analytics/queries/platform_cost_log.sql
Normal file
100
autogpt_platform/analytics/queries/platform_cost_log.sql
Normal file
@@ -0,0 +1,100 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.platform_cost_log
|
||||
-- Looker source alias: ds115 | Charts: 0
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per platform cost log entry (last 90 days).
|
||||
-- Tracks real API spend at the call level: provider, model,
|
||||
-- token counts (including Anthropic cache tokens), cost in
|
||||
-- microdollars, and the block/execution that incurred the cost.
|
||||
-- Joins the User table to provide email for per-user breakdowns.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.PlatformCostLog — Per-call cost records
|
||||
-- platform.User — User email
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Log entry UUID
|
||||
-- createdAt TIMESTAMPTZ When the cost was recorded
|
||||
-- userId TEXT User who incurred the cost (nullable)
|
||||
-- email TEXT User email (nullable)
|
||||
-- graphExecId TEXT Graph execution UUID (nullable)
|
||||
-- nodeExecId TEXT Node execution UUID (nullable)
|
||||
-- blockName TEXT Block that made the API call (nullable)
|
||||
-- provider TEXT API provider, lowercase (e.g. 'openai', 'anthropic')
|
||||
-- model TEXT Model name (nullable)
|
||||
-- trackingType TEXT Cost unit: 'tokens' | 'cost_usd' | 'characters' | etc.
|
||||
-- costMicrodollars BIGINT Cost in microdollars (divide by 1,000,000 for USD)
|
||||
-- costUsd FLOAT Cost in USD (costMicrodollars / 1,000,000)
|
||||
-- inputTokens INT Prompt/input tokens (nullable)
|
||||
-- outputTokens INT Completion/output tokens (nullable)
|
||||
-- cacheReadTokens INT Anthropic cache-read tokens billed at 10% (nullable)
|
||||
-- cacheCreationTokens INT Anthropic cache-write tokens billed at 125% (nullable)
|
||||
-- totalTokens INT inputTokens + outputTokens (nullable if either is null)
|
||||
-- duration FLOAT API call duration in seconds (nullable)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Total spend by provider (last 90 days)
|
||||
-- SELECT provider, SUM("costUsd") AS total_usd, COUNT(*) AS calls
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- GROUP BY 1 ORDER BY total_usd DESC;
|
||||
--
|
||||
-- -- Spend by model
|
||||
-- SELECT provider, model, SUM("costUsd") AS total_usd,
|
||||
-- SUM("inputTokens") AS input_tokens,
|
||||
-- SUM("outputTokens") AS output_tokens
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- WHERE model IS NOT NULL
|
||||
-- GROUP BY 1, 2 ORDER BY total_usd DESC;
|
||||
--
|
||||
-- -- Top 20 users by spend
|
||||
-- SELECT "userId", email, SUM("costUsd") AS total_usd, COUNT(*) AS calls
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- WHERE "userId" IS NOT NULL
|
||||
-- GROUP BY 1, 2 ORDER BY total_usd DESC LIMIT 20;
|
||||
--
|
||||
-- -- Daily spend trend
|
||||
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
|
||||
-- SUM("costUsd") AS daily_usd,
|
||||
-- COUNT(*) AS calls
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
--
|
||||
-- -- Cache hit rate for Anthropic (cache reads vs total reads)
|
||||
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
|
||||
-- SUM("cacheReadTokens")::float /
|
||||
-- NULLIF(SUM("inputTokens" + COALESCE("cacheReadTokens", 0)), 0) AS cache_hit_rate
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- WHERE provider = 'anthropic'
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
p."id" AS id,
|
||||
p."createdAt" AS createdAt,
|
||||
p."userId" AS userId,
|
||||
u."email" AS email,
|
||||
p."graphExecId" AS graphExecId,
|
||||
p."nodeExecId" AS nodeExecId,
|
||||
p."blockName" AS blockName,
|
||||
p."provider" AS provider,
|
||||
p."model" AS model,
|
||||
p."trackingType" AS trackingType,
|
||||
p."costMicrodollars" AS costMicrodollars,
|
||||
p."costMicrodollars"::float / 1000000.0 AS costUsd,
|
||||
p."inputTokens" AS inputTokens,
|
||||
p."outputTokens" AS outputTokens,
|
||||
p."cacheReadTokens" AS cacheReadTokens,
|
||||
p."cacheCreationTokens" AS cacheCreationTokens,
|
||||
CASE
|
||||
WHEN p."inputTokens" IS NOT NULL AND p."outputTokens" IS NOT NULL
|
||||
THEN p."inputTokens" + p."outputTokens"
|
||||
ELSE NULL
|
||||
END AS totalTokens,
|
||||
p."duration" AS duration
|
||||
FROM platform."PlatformCostLog" p
|
||||
LEFT JOIN platform."User" u ON u."id" = p."userId"
|
||||
WHERE p."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
@@ -58,6 +58,17 @@ V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Graphiti Temporal Knowledge Graph Memory
|
||||
# Rollout controlled by LaunchDarkly flag "graphiti-memory"
|
||||
# LLM key falls back to CHAT_API_KEY (AutoPilot), then OPEN_ROUTER_API_KEY.
|
||||
# Embedder key falls back to CHAT_OPENAI_API_KEY (AutoPilot), then OPENAI_API_KEY.
|
||||
GRAPHITI_FALKORDB_HOST=localhost
|
||||
GRAPHITI_FALKORDB_PORT=6380
|
||||
GRAPHITI_FALKORDB_PASSWORD=
|
||||
GRAPHITI_LLM_MODEL=gpt-4.1-mini
|
||||
GRAPHITI_EMBEDDER_MODEL=text-embedding-3-small
|
||||
GRAPHITI_SEMAPHORE_LIMIT=5
|
||||
|
||||
# Langfuse Prompt Management
|
||||
# Used for managing the CoPilot system prompt externally
|
||||
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
|
||||
@@ -178,6 +189,7 @@ SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Business & Marketing Tools
|
||||
AGENTMAIL_API_KEY=
|
||||
APOLLO_API_KEY=
|
||||
ENRICHLAYER_API_KEY=
|
||||
AYRSHARE_API_KEY=
|
||||
|
||||
227
autogpt_platform/backend/AGENTS.md
Normal file
227
autogpt_platform/backend/AGENTS.md
Normal file
@@ -0,0 +1,227 @@
|
||||
# Backend
|
||||
|
||||
This file provides guidance to coding agents when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
#### Handling files in blocks with `store_media_file()`
|
||||
|
||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||
|
||||
| Format | Use When | Returns |
|
||||
|--------|----------|---------|
|
||||
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# INPUT: Need to process file locally with ffmpeg
|
||||
local_path = await store_media_file(
|
||||
file=input_data.video,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||
|
||||
# INPUT: Need to send to external API like Replicate
|
||||
image_b64 = await store_media_file(
|
||||
file=input_data.image,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
|
||||
|
||||
# OUTPUT: Returning result from block
|
||||
result_url = await store_media_file(
|
||||
file=generated_image_url,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", result_url
|
||||
# In CoPilot: result_url = "workspace://abc123"
|
||||
# In graphs: result_url = "data:image/png;base64,..."
|
||||
```
|
||||
|
||||
**Key points:**
|
||||
|
||||
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||
|
||||
### Modifying the API
|
||||
|
||||
1. Update route in `backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Workspace & Media Files
|
||||
|
||||
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
|
||||
- Working on CoPilot file upload/download features
|
||||
- Building blocks that handle `MediaFileType` inputs/outputs
|
||||
- Modifying `WorkspaceManager` or `store_media_file()`
|
||||
- Debugging file persistence or virus scanning issues
|
||||
|
||||
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
- Located in `backend/api/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
@@ -1,227 +1 @@
|
||||
# CLAUDE.md - Backend
|
||||
|
||||
This file provides guidance to Claude Code when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
#### Handling files in blocks with `store_media_file()`
|
||||
|
||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||
|
||||
| Format | Use When | Returns |
|
||||
|--------|----------|---------|
|
||||
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# INPUT: Need to process file locally with ffmpeg
|
||||
local_path = await store_media_file(
|
||||
file=input_data.video,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||
|
||||
# INPUT: Need to send to external API like Replicate
|
||||
image_b64 = await store_media_file(
|
||||
file=input_data.image,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
|
||||
|
||||
# OUTPUT: Returning result from block
|
||||
result_url = await store_media_file(
|
||||
file=generated_image_url,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", result_url
|
||||
# In CoPilot: result_url = "workspace://abc123"
|
||||
# In graphs: result_url = "data:image/png;base64,..."
|
||||
```
|
||||
|
||||
**Key points:**
|
||||
|
||||
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||
|
||||
### Modifying the API
|
||||
|
||||
1. Update route in `backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Workspace & Media Files
|
||||
|
||||
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
|
||||
- Working on CoPilot file upload/download features
|
||||
- Building blocks that handle `MediaFileType` inputs/outputs
|
||||
- Modifying `WorkspaceManager` or `store_media_file()`
|
||||
- Debugging file persistence or virus scanning issues
|
||||
|
||||
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
- Located in `backend/api/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
@AGENTS.md
|
||||
|
||||
166
autogpt_platform/backend/agents/calculator-agent.json
Normal file
166
autogpt_platform/backend/agents/calculator-agent.json
Normal file
@@ -0,0 +1,166 @@
|
||||
{
|
||||
"id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"version": 2,
|
||||
"is_active": true,
|
||||
"name": "Calculator agent",
|
||||
"description": "",
|
||||
"instructions": null,
|
||||
"recommended_schedule_cron": null,
|
||||
"forked_from_id": null,
|
||||
"forked_from_version": null,
|
||||
"user_id": "",
|
||||
"created_at": "2026-04-13T03:45:11.241Z",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"block_id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
"input_default": {
|
||||
"name": "Input",
|
||||
"secret": false,
|
||||
"advanced": false
|
||||
},
|
||||
"metadata": {
|
||||
"position": {
|
||||
"x": -188.2244873046875,
|
||||
"y": 95
|
||||
}
|
||||
},
|
||||
"input_links": [],
|
||||
"output_links": [
|
||||
{
|
||||
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
|
||||
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"source_name": "result",
|
||||
"sink_name": "a",
|
||||
"is_static": true
|
||||
}
|
||||
],
|
||||
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"graph_version": 2,
|
||||
"webhook_id": null
|
||||
},
|
||||
{
|
||||
"id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"block_id": "363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
"input_default": {
|
||||
"name": "Output",
|
||||
"secret": false,
|
||||
"advanced": false,
|
||||
"escape_html": false
|
||||
},
|
||||
"metadata": {
|
||||
"position": {
|
||||
"x": 825.198974609375,
|
||||
"y": 123.75
|
||||
}
|
||||
},
|
||||
"input_links": [
|
||||
{
|
||||
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
|
||||
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": false
|
||||
}
|
||||
],
|
||||
"output_links": [],
|
||||
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"graph_version": 2,
|
||||
"webhook_id": null
|
||||
},
|
||||
{
|
||||
"id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"block_id": "b1ab9b19-67a6-406d-abf5-2dba76d00c79",
|
||||
"input_default": {
|
||||
"b": 34,
|
||||
"operation": "Add",
|
||||
"round_result": false
|
||||
},
|
||||
"metadata": {
|
||||
"position": {
|
||||
"x": 323.0255126953125,
|
||||
"y": 121.25
|
||||
}
|
||||
},
|
||||
"input_links": [
|
||||
{
|
||||
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
|
||||
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"source_name": "result",
|
||||
"sink_name": "a",
|
||||
"is_static": true
|
||||
}
|
||||
],
|
||||
"output_links": [
|
||||
{
|
||||
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
|
||||
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": false
|
||||
}
|
||||
],
|
||||
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"graph_version": 2,
|
||||
"webhook_id": null
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
|
||||
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": false
|
||||
},
|
||||
{
|
||||
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
|
||||
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"source_name": "result",
|
||||
"sink_name": "a",
|
||||
"is_static": true
|
||||
}
|
||||
],
|
||||
"sub_graphs": [],
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"Input": {
|
||||
"advanced": false,
|
||||
"secret": false,
|
||||
"title": "Input"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"Input"
|
||||
]
|
||||
},
|
||||
"output_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"Output": {
|
||||
"advanced": false,
|
||||
"secret": false,
|
||||
"title": "Output"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"Output"
|
||||
]
|
||||
},
|
||||
"has_external_trigger": false,
|
||||
"has_human_in_the_loop": false,
|
||||
"has_sensitive_action": false,
|
||||
"trigger_setup_info": null,
|
||||
"credentials_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
@@ -31,7 +31,10 @@ from backend.data.model import (
|
||||
UserPasswordCredentials,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -618,6 +621,11 @@ async def delete_credential(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):
|
||||
|
||||
def _create_ephemeral_session(user_id: str) -> ChatSession:
|
||||
"""Create an ephemeral session for stateless API requests."""
|
||||
return ChatSession.new(user_id)
|
||||
return ChatSession.new(user_id, dry_run=False)
|
||||
|
||||
|
||||
@tools_router.post(
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, Query, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.platform_cost import (
|
||||
CostLogRow,
|
||||
PlatformCostDashboard,
|
||||
get_platform_cost_dashboard,
|
||||
get_platform_cost_logs,
|
||||
get_platform_cost_logs_for_export,
|
||||
)
|
||||
from backend.util.models import Pagination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/platform-costs",
|
||||
tags=["platform-cost", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
class PlatformCostLogsResponse(BaseModel):
|
||||
logs: list[CostLogRow]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
@router.get(
|
||||
"/dashboard",
|
||||
response_model=PlatformCostDashboard,
|
||||
summary="Get Platform Cost Dashboard",
|
||||
)
|
||||
async def get_cost_dashboard(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
start: datetime | None = Query(None),
|
||||
end: datetime | None = Query(None),
|
||||
provider: str | None = Query(None),
|
||||
user_id: str | None = Query(None),
|
||||
model: str | None = Query(None),
|
||||
block_name: str | None = Query(None),
|
||||
tracking_type: str | None = Query(None),
|
||||
graph_exec_id: str | None = Query(None),
|
||||
):
|
||||
logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
|
||||
return await get_platform_cost_dashboard(
|
||||
start=start,
|
||||
end=end,
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
model=model,
|
||||
block_name=block_name,
|
||||
tracking_type=tracking_type,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/logs",
|
||||
response_model=PlatformCostLogsResponse,
|
||||
summary="Get Platform Cost Logs",
|
||||
)
|
||||
async def get_cost_logs(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
start: datetime | None = Query(None),
|
||||
end: datetime | None = Query(None),
|
||||
provider: str | None = Query(None),
|
||||
user_id: str | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
model: str | None = Query(None),
|
||||
block_name: str | None = Query(None),
|
||||
tracking_type: str | None = Query(None),
|
||||
graph_exec_id: str | None = Query(None),
|
||||
):
|
||||
logger.info("Admin %s fetching platform cost logs", admin_user_id)
|
||||
logs, total = await get_platform_cost_logs(
|
||||
start=start,
|
||||
end=end,
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
model=model,
|
||||
block_name=block_name,
|
||||
tracking_type=tracking_type,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
return PlatformCostLogsResponse(
|
||||
logs=logs,
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PlatformCostExportResponse(BaseModel):
|
||||
logs: list[CostLogRow]
|
||||
total_rows: int
|
||||
truncated: bool
|
||||
|
||||
|
||||
@router.get(
|
||||
"/logs/export",
|
||||
response_model=PlatformCostExportResponse,
|
||||
summary="Export Platform Cost Logs",
|
||||
)
|
||||
async def export_cost_logs(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
start: datetime | None = Query(None),
|
||||
end: datetime | None = Query(None),
|
||||
provider: str | None = Query(None),
|
||||
user_id: str | None = Query(None),
|
||||
model: str | None = Query(None),
|
||||
block_name: str | None = Query(None),
|
||||
tracking_type: str | None = Query(None),
|
||||
graph_exec_id: str | None = Query(None),
|
||||
):
|
||||
logger.info("Admin %s exporting platform cost logs", admin_user_id)
|
||||
logs, truncated = await get_platform_cost_logs_for_export(
|
||||
start=start,
|
||||
end=end,
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
model=model,
|
||||
block_name=block_name,
|
||||
tracking_type=tracking_type,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
return PlatformCostExportResponse(
|
||||
logs=logs,
|
||||
total_rows=len(logs),
|
||||
truncated=truncated,
|
||||
)
|
||||
@@ -0,0 +1,291 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.platform_cost import CostLogRow, PlatformCostDashboard
|
||||
|
||||
from .platform_cost_routes import router as platform_cost_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(platform_cost_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_get_dashboard_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
real_dashboard = PlatformCostDashboard(
|
||||
by_provider=[],
|
||||
by_user=[],
|
||||
total_cost_microdollars=0,
|
||||
total_requests=0,
|
||||
total_users=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
|
||||
AsyncMock(return_value=real_dashboard),
|
||||
)
|
||||
|
||||
response = client.get("/platform-costs/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "by_provider" in data
|
||||
assert "by_user" in data
|
||||
assert data["total_cost_microdollars"] == 0
|
||||
|
||||
|
||||
def test_get_logs_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
|
||||
AsyncMock(return_value=([], 0)),
|
||||
)
|
||||
|
||||
response = client.get("/platform-costs/logs")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["logs"] == []
|
||||
assert data["pagination"]["total_items"] == 0
|
||||
|
||||
|
||||
def test_get_dashboard_with_filters(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
real_dashboard = PlatformCostDashboard(
|
||||
by_provider=[],
|
||||
by_user=[],
|
||||
total_cost_microdollars=0,
|
||||
total_requests=0,
|
||||
total_users=0,
|
||||
)
|
||||
mock_dashboard = AsyncMock(return_value=real_dashboard)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
|
||||
mock_dashboard,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/platform-costs/dashboard",
|
||||
params={
|
||||
"start": "2026-01-01T00:00:00",
|
||||
"end": "2026-04-01T00:00:00",
|
||||
"provider": "openai",
|
||||
"user_id": "test-user-123",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_dashboard.assert_called_once()
|
||||
call_kwargs = mock_dashboard.call_args.kwargs
|
||||
assert call_kwargs["provider"] == "openai"
|
||||
assert call_kwargs["user_id"] == "test-user-123"
|
||||
assert call_kwargs["start"] is not None
|
||||
assert call_kwargs["end"] is not None
|
||||
|
||||
|
||||
def test_get_logs_with_pagination(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
|
||||
AsyncMock(return_value=([], 0)),
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/platform-costs/logs",
|
||||
params={"page": 2, "page_size": 25, "provider": "anthropic"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["pagination"]["current_page"] == 2
|
||||
assert data["pagination"]["page_size"] == 25
|
||||
|
||||
|
||||
def test_get_dashboard_requires_admin() -> None:
|
||||
import fastapi
|
||||
from fastapi import HTTPException
|
||||
|
||||
def reject_jwt(request: fastapi.Request):
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = reject_jwt
|
||||
try:
|
||||
response = client.get("/platform-costs/dashboard")
|
||||
assert response.status_code == 401
|
||||
response = client.get("/platform-costs/logs")
|
||||
assert response.status_code == 401
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_get_dashboard_rejects_non_admin(mock_jwt_user, mock_jwt_admin) -> None:
|
||||
"""Non-admin JWT must be rejected with 403 by requires_admin_user."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
try:
|
||||
response = client.get("/platform-costs/dashboard")
|
||||
assert response.status_code == 403
|
||||
response = client.get("/platform-costs/logs")
|
||||
assert response.status_code == 403
|
||||
finally:
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
|
||||
|
||||
def test_get_logs_invalid_page_size_too_large() -> None:
|
||||
"""page_size > 200 must be rejected with 422."""
|
||||
response = client.get("/platform-costs/logs", params={"page_size": 201})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_logs_invalid_page_size_zero() -> None:
|
||||
"""page_size = 0 (below ge=1) must be rejected with 422."""
|
||||
response = client.get("/platform-costs/logs", params={"page_size": 0})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_logs_invalid_page_negative() -> None:
|
||||
"""page < 1 must be rejected with 422."""
|
||||
response = client.get("/platform-costs/logs", params={"page": 0})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_dashboard_invalid_date_format() -> None:
|
||||
"""Malformed start date must be rejected with 422."""
|
||||
response = client.get("/platform-costs/dashboard", params={"start": "not-a-date"})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_dashboard_repeated_requests(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Repeated requests to the dashboard route both return 200."""
|
||||
real_dashboard = PlatformCostDashboard(
|
||||
by_provider=[],
|
||||
by_user=[],
|
||||
total_cost_microdollars=42,
|
||||
total_requests=1,
|
||||
total_users=1,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
|
||||
AsyncMock(return_value=real_dashboard),
|
||||
)
|
||||
|
||||
r1 = client.get("/platform-costs/dashboard")
|
||||
r2 = client.get("/platform-costs/dashboard")
|
||||
|
||||
assert r1.status_code == 200
|
||||
assert r2.status_code == 200
|
||||
assert r1.json()["total_cost_microdollars"] == 42
|
||||
assert r2.json()["total_cost_microdollars"] == 42
|
||||
|
||||
|
||||
def _make_cost_log_row() -> CostLogRow:
|
||||
return CostLogRow(
|
||||
id="log-1",
|
||||
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
user_id="user-1",
|
||||
email="u***@example.com",
|
||||
graph_exec_id="graph-1",
|
||||
node_exec_id="node-1",
|
||||
block_name="LlmCallBlock",
|
||||
provider="anthropic",
|
||||
tracking_type="token",
|
||||
cost_microdollars=500,
|
||||
input_tokens=100,
|
||||
output_tokens=50,
|
||||
cache_read_tokens=10,
|
||||
cache_creation_tokens=5,
|
||||
duration=1.5,
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
)
|
||||
|
||||
|
||||
def test_export_logs_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
row = _make_cost_log_row()
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export",
|
||||
AsyncMock(return_value=([row], False)),
|
||||
)
|
||||
|
||||
response = client.get("/platform-costs/logs/export")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_rows"] == 1
|
||||
assert data["truncated"] is False
|
||||
assert len(data["logs"]) == 1
|
||||
assert data["logs"][0]["cache_read_tokens"] == 10
|
||||
assert data["logs"][0]["cache_creation_tokens"] == 5
|
||||
|
||||
|
||||
def test_export_logs_truncated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
rows = [_make_cost_log_row() for _ in range(3)]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export",
|
||||
AsyncMock(return_value=(rows, True)),
|
||||
)
|
||||
|
||||
response = client.get("/platform-costs/logs/export")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_rows"] == 3
|
||||
assert data["truncated"] is True
|
||||
|
||||
|
||||
def test_export_logs_with_filters(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mock_export = AsyncMock(return_value=([], False))
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export",
|
||||
mock_export,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/platform-costs/logs/export",
|
||||
params={
|
||||
"provider": "anthropic",
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"block_name": "LlmCallBlock",
|
||||
"tracking_type": "token",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_export.assert_called_once()
|
||||
call_kwargs = mock_export.call_args.kwargs
|
||||
assert call_kwargs["provider"] == "anthropic"
|
||||
assert call_kwargs["model"] == "claude-3-5-sonnet-20241022"
|
||||
assert call_kwargs["block_name"] == "LlmCallBlock"
|
||||
assert call_kwargs["tracking_type"] == "token"
|
||||
|
||||
|
||||
def test_export_logs_requires_admin() -> None:
|
||||
import fastapi
|
||||
from fastapi import HTTPException
|
||||
|
||||
def reject_jwt(request: fastapi.Request):
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = reject_jwt
|
||||
try:
|
||||
response = client.get("/platform-costs/logs/export")
|
||||
assert response.status_code == 401
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
@@ -9,11 +9,14 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.rate_limit import (
|
||||
SubscriptionTier,
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
get_user_tier,
|
||||
reset_user_usage,
|
||||
set_user_tier,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_email_by_id
|
||||
from backend.data.user import get_user_by_email, get_user_email_by_id, search_users
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,6 +36,17 @@ class UserRateLimitResponse(BaseModel):
|
||||
weekly_token_limit: int
|
||||
daily_tokens_used: int
|
||||
weekly_tokens_used: int
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
class UserTierResponse(BaseModel):
|
||||
user_id: str
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
class SetUserTierRequest(BaseModel):
|
||||
user_id: str
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
async def _resolve_user_id(
|
||||
@@ -86,10 +100,10 @@ async def get_user_rate_limit(
|
||||
|
||||
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
resolved_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
|
||||
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=resolved_id,
|
||||
@@ -98,6 +112,7 @@ async def get_user_rate_limit(
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@@ -125,10 +140,10 @@ async def reset_user_rate_limit(
|
||||
logger.exception("Failed to reset user usage")
|
||||
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
@@ -143,4 +158,102 @@ async def reset_user_rate_limit(
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit/tier",
|
||||
response_model=UserTierResponse,
|
||||
summary="Get User Rate Limit Tier",
|
||||
)
|
||||
async def get_user_rate_limit_tier(
|
||||
user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserTierResponse:
|
||||
"""Get a user's current rate-limit tier. Admin-only.
|
||||
|
||||
Returns 404 if the user does not exist in the database.
|
||||
"""
|
||||
logger.info("Admin %s checking tier for user %s", admin_user_id, user_id)
|
||||
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
if resolved_email is None:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
|
||||
|
||||
tier = await get_user_tier(user_id)
|
||||
return UserTierResponse(user_id=user_id, tier=tier)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rate_limit/tier",
|
||||
response_model=UserTierResponse,
|
||||
summary="Set User Rate Limit Tier",
|
||||
)
|
||||
async def set_user_rate_limit_tier(
|
||||
request: SetUserTierRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserTierResponse:
|
||||
"""Set a user's rate-limit tier. Admin-only.
|
||||
|
||||
Returns 404 if the user does not exist in the database.
|
||||
"""
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(request.user_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to resolve email for user %s",
|
||||
request.user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
resolved_email = None
|
||||
|
||||
if resolved_email is None:
|
||||
raise HTTPException(status_code=404, detail=f"User {request.user_id} not found")
|
||||
|
||||
old_tier = await get_user_tier(request.user_id)
|
||||
logger.info(
|
||||
"Admin %s changing tier for user %s (%s): %s -> %s",
|
||||
admin_user_id,
|
||||
request.user_id,
|
||||
resolved_email,
|
||||
old_tier.value,
|
||||
request.tier.value,
|
||||
)
|
||||
try:
|
||||
await set_user_tier(request.user_id, request.tier)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to set user tier")
|
||||
raise HTTPException(status_code=500, detail="Failed to set tier") from e
|
||||
|
||||
return UserTierResponse(user_id=request.user_id, tier=request.tier)
|
||||
|
||||
|
||||
class UserSearchResult(BaseModel):
|
||||
user_id: str
|
||||
user_email: Optional[str] = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit/search_users",
|
||||
response_model=list[UserSearchResult],
|
||||
summary="Search Users by Name or Email",
|
||||
)
|
||||
async def admin_search_users(
|
||||
query: str,
|
||||
limit: int = 20,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> list[UserSearchResult]:
|
||||
"""Search users by partial email or name. Admin-only.
|
||||
|
||||
Queries the User table directly — returns results even for users
|
||||
without credit transaction history.
|
||||
"""
|
||||
if len(query.strip()) < 3:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Search query must be at least 3 characters.",
|
||||
)
|
||||
logger.info("Admin %s searching users with query=%r", admin_user_id, query)
|
||||
results = await search_users(query, limit=max(1, min(limit, 50)))
|
||||
return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
|
||||
|
||||
from .rate_limit_admin_routes import router as rate_limit_admin_router
|
||||
|
||||
@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
@@ -89,6 +89,7 @@ def test_get_rate_limit(
|
||||
assert data["weekly_token_limit"] == 12_500_000
|
||||
assert data["daily_tokens_used"] == 500_000
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
@@ -162,6 +163,7 @@ def test_reset_user_usage_daily_only(
|
||||
assert data["daily_tokens_used"] == 0
|
||||
# Weekly is untouched
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
|
||||
|
||||
@@ -192,6 +194,7 @@ def test_reset_user_usage_daily_and_weekly(
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
assert data["weekly_tokens_used"] == 0
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
|
||||
|
||||
@@ -228,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
@@ -261,3 +264,303 @@ def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
json={"user_id": "test"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tier management endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_user_tier(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting a user's rate-limit tier."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.PRO,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "PRO"
|
||||
|
||||
|
||||
def test_get_user_tier_user_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that getting tier for a non-existent user returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test setting a user's rate-limit tier (upgrade)."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.FREE,
|
||||
)
|
||||
mock_set = mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "ENTERPRISE"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "ENTERPRISE"
|
||||
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE)
|
||||
|
||||
|
||||
def test_set_user_tier_downgrade(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test downgrading a user's tier from PRO to FREE."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.PRO,
|
||||
)
|
||||
mock_set = mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "FREE"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "FREE"
|
||||
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
|
||||
|
||||
|
||||
def test_set_user_tier_invalid_tier(
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting an invalid tier returns 422."""
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "invalid"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_set_user_tier_invalid_tier_uppercase(
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422.
|
||||
|
||||
Regression: ensures Pydantic enum validation rejects values that are not
|
||||
members of SubscriptionTier, even when they look like valid enum names.
|
||||
"""
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "INVALID"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
body = response.json()
|
||||
assert "detail" in body
|
||||
|
||||
|
||||
def test_set_user_tier_email_lookup_failure_returns_404(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that email lookup failure returns 404 (user unverifiable)."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection failed"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier_user_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting tier for a non-existent user returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier_db_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that DB failure on set tier returns 500."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.FREE,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection refused"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that tier admin endpoints require admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": "test"})
|
||||
assert response.status_code == 403
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": "test", "tier": "PRO"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ─── search_users endpoint ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_search_users_returns_matching_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Partial search should return all matching users from the User table."""
|
||||
mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[
|
||||
("user-1", "zamil.majdy@gmail.com"),
|
||||
("user-2", "zamil.majdy@agpt.co"),
|
||||
],
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"})
|
||||
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
assert len(results) == 2
|
||||
assert results[0]["user_email"] == "zamil.majdy@gmail.com"
|
||||
assert results[1]["user_email"] == "zamil.majdy@agpt.co"
|
||||
|
||||
|
||||
def test_search_users_empty_results(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Search with no matches returns empty list."""
|
||||
mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/rate_limit/search_users", params={"query": "nonexistent"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
def test_search_users_short_query_rejected(
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Query shorter than 3 characters should return 400."""
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "ab"})
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_search_users_negative_limit_clamped(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Negative limit should be clamped to 1, not passed through."""
|
||||
mock_search = mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/rate_limit/search_users", params={"query": "test", "limit": -1}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_search.assert_awaited_once_with("test", limit=1)
|
||||
|
||||
|
||||
def test_search_users_requires_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that the search_users endpoint requires admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "test"})
|
||||
assert response.status_code == 403
|
||||
|
||||
@@ -11,15 +11,17 @@ from autogpt_libs import auth
|
||||
from fastapi import APIRouter, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from backend.copilot.db import get_chat_messages_paginated
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
ChatSessionMetadata,
|
||||
append_and_save_message,
|
||||
create_chat_session,
|
||||
delete_chat_session,
|
||||
@@ -40,6 +42,7 @@ from backend.copilot.rate_limit import (
|
||||
reset_daily_usage,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.service import strip_injected_context_for_display
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
AgentDetailsResponse,
|
||||
@@ -58,6 +61,10 @@ from backend.copilot.tools.models import (
|
||||
InputValidationErrorResponse,
|
||||
MCPToolOutputResponse,
|
||||
MCPToolsDiscoveredResponse,
|
||||
MemoryForgetCandidatesResponse,
|
||||
MemoryForgetConfirmResponse,
|
||||
MemorySearchResponse,
|
||||
MemoryStoreResponse,
|
||||
NeedLoginResponse,
|
||||
NoResultsResponse,
|
||||
SetupRequirementsResponse,
|
||||
@@ -98,6 +105,28 @@ router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
|
||||
|
||||
def _strip_injected_context(message: dict) -> dict:
|
||||
"""Hide server-injected context blocks from the API response.
|
||||
|
||||
Returns a **shallow copy** of *message* with all server-injected XML
|
||||
blocks removed from ``content`` (if applicable). The original dict is
|
||||
never mutated, so callers can safely pass live session dicts without
|
||||
risking side-effects.
|
||||
|
||||
Handles all three injected block types — ``<memory_context>``,
|
||||
``<env_context>``, and ``<user_context>`` — regardless of the order they
|
||||
appear at the start of the message. Only ``user``-role messages with
|
||||
string content are touched; assistant / multimodal blocks pass through
|
||||
unchanged.
|
||||
"""
|
||||
if message.get("role") == "user" and isinstance(message.get("content"), str):
|
||||
result = message.copy()
|
||||
result["content"] = strip_injected_context_for_display(message["content"])
|
||||
return result
|
||||
return message
|
||||
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
@@ -110,6 +139,28 @@ class StreamChatRequest(BaseModel):
|
||||
file_ids: list[str] | None = Field(
|
||||
default=None, max_length=20
|
||||
) # Workspace file IDs attached to this message
|
||||
mode: CopilotMode | None = Field(
|
||||
default=None,
|
||||
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
|
||||
"If None, uses the server default (extended_thinking).",
|
||||
)
|
||||
model: CopilotLlmModel | None = Field(
|
||||
default=None,
|
||||
description="Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. "
|
||||
"If None, the server applies per-user LD targeting then falls back to config.",
|
||||
)
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""Request model for creating a new chat session.
|
||||
|
||||
``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
|
||||
Extra/unknown fields are rejected (422) to prevent silent mis-use.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
dry_run: bool = False
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
@@ -118,6 +169,7 @@ class CreateSessionResponse(BaseModel):
|
||||
id: str
|
||||
created_at: str
|
||||
user_id: str | None
|
||||
metadata: ChatSessionMetadata = ChatSessionMetadata()
|
||||
|
||||
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
@@ -136,8 +188,11 @@ class SessionDetailResponse(BaseModel):
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
has_more_messages: bool = False
|
||||
oldest_sequence: int | None = None
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
metadata: ChatSessionMetadata = ChatSessionMetadata()
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -248,6 +303,7 @@ async def list_sessions(
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
request: CreateSessionRequest | None = None,
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
@@ -256,22 +312,28 @@ async def create_session(
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user ID parsed from the JWT (required).
|
||||
request: Optional request body. When provided, ``dry_run=True``
|
||||
forces run_block and run_agent calls to use dry-run simulation.
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created session.
|
||||
|
||||
"""
|
||||
dry_run = request.dry_run if request else False
|
||||
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
|
||||
f"{', dry_run=True' if dry_run else ''}"
|
||||
)
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
session = await create_chat_session(user_id, dry_run=dry_run)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id,
|
||||
metadata=session.metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -324,6 +386,31 @@ async def delete_session(
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/sessions/{session_id}/stream",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=204,
|
||||
)
|
||||
async def disconnect_session_stream(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> Response:
|
||||
"""Disconnect all active SSE listeners for a session.
|
||||
|
||||
Called by the frontend when the user switches away from a chat so the
|
||||
backend releases XREAD listeners immediately rather than waiting for
|
||||
the 5-10 s timeout.
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
await stream_registry.disconnect_all_listeners(session_id)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/title",
|
||||
summary="Update session title",
|
||||
@@ -367,59 +454,67 @@ async def update_session_title_route(
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
before_sequence: int | None = Query(default=None, ge=0),
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns active_stream info for reconnection.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
||||
|
||||
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
|
||||
When no pagination params are provided, returns the most recent messages.
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
page = await get_chat_messages_paginated(
|
||||
session_id, limit, before_sequence, user_id=user_id
|
||||
)
|
||||
if page is None:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
messages = [
|
||||
_strip_injected_context(message.model_dump()) for message in page.messages
|
||||
]
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
# Only check active stream on initial load (not on "load more" requests)
|
||||
active_stream_info = None
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_session:
|
||||
# Keep the assistant message (including tool_calls) so the frontend can
|
||||
# render the correct tool UI (e.g. CreateAgent with mini game).
|
||||
# convertChatSessionToUiMessages handles isComplete=false by setting
|
||||
# tool parts without output to state "input-available".
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
if before_sequence is None:
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
if active_session:
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Skip session metadata on "load more" — frontend only needs messages
|
||||
if before_sequence is not None:
|
||||
return SessionDetailResponse(
|
||||
id=page.session.session_id,
|
||||
created_at=page.session.started_at.isoformat(),
|
||||
updated_at=page.session.updated_at.isoformat(),
|
||||
user_id=page.session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=None,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
total_prompt_tokens=0,
|
||||
total_completion_tokens=0,
|
||||
)
|
||||
|
||||
# Sum token usage from session
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
total_prompt = sum(u.prompt_tokens for u in page.session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in page.session.usage)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
id=page.session.session_id,
|
||||
created_at=page.session.started_at.isoformat(),
|
||||
updated_at=page.session.updated_at.isoformat(),
|
||||
user_id=page.session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
metadata=page.session.metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -433,8 +528,9 @@ async def get_copilot_usage(
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
Global defaults sourced from LaunchDarkly (falling back to config).
|
||||
Includes the user's rate-limit tier.
|
||||
"""
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
return await get_usage_status(
|
||||
@@ -442,6 +538,7 @@ async def get_copilot_usage(
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@@ -493,7 +590,7 @@ async def reset_copilot_usage(
|
||||
detail="Rate limit reset is not available (credit system is disabled).",
|
||||
)
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
|
||||
@@ -527,10 +624,13 @@ async def reset_copilot_usage(
|
||||
|
||||
try:
|
||||
# Verify the user is actually at or over their daily limit.
|
||||
# (rate_limit_reset_cost intentionally omitted — this object is only
|
||||
# used for limit checks, not returned to the client.)
|
||||
usage_status = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
tier=tier,
|
||||
)
|
||||
if daily_limit > 0 and usage_status.daily.used < daily_limit:
|
||||
raise HTTPException(
|
||||
@@ -606,6 +706,7 @@ async def reset_copilot_usage(
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
return RateLimitResetResponse(
|
||||
@@ -716,7 +817,7 @@ async def stream_chat_post(
|
||||
# Global defaults sourced from LaunchDarkly, falling back to config.
|
||||
if user_id:
|
||||
try:
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
daily_limit, weekly_limit, _ = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
await check_rate_limit(
|
||||
@@ -761,57 +862,66 @@ async def stream_chat_post(
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
# message loss from concurrent requests.
|
||||
# saved yet. append_and_save_message returns None when a duplicate is
|
||||
# detected — in that case skip enqueue to avoid processing the message twice.
|
||||
is_duplicate_message = False
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
if request.is_user_message:
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
is_duplicate_message = (
|
||||
await append_and_save_message(session_id, message)
|
||||
) is None
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
if not is_duplicate_message and request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
await append_and_save_message(session_id, message)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
|
||||
session_create_start = time.perf_counter()
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
turn_id=turn_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
|
||||
subscribe_from_id = "0-0"
|
||||
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=request.message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
)
|
||||
# Create a task in the stream registry for reconnection support.
|
||||
# For duplicate messages, skip create_session entirely so the infra-retry
|
||||
# client subscribes to the *existing* turn's Redis stream and receives the
|
||||
# in-progress executor output rather than an empty stream.
|
||||
turn_id = ""
|
||||
if not is_duplicate_message:
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
session_create_start = time.perf_counter()
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
turn_id=turn_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=request.message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
mode=request.mode,
|
||||
model=request.model,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue"
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
@@ -819,6 +929,9 @@ async def stream_chat_post(
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
|
||||
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
|
||||
subscribe_from_id = "0-0"
|
||||
|
||||
# SSE endpoint that subscribes to the task's stream
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import time as time_module
|
||||
@@ -843,7 +956,6 @@ async def stream_chat_post(
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
@@ -873,7 +985,6 @@ async def stream_chat_post(
|
||||
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
@@ -888,6 +999,7 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
@@ -902,7 +1014,6 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
pass # Client disconnected - background task continues
|
||||
except Exception as e:
|
||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||
logger.error(
|
||||
@@ -1174,7 +1285,7 @@ async def health_check() -> dict:
|
||||
)
|
||||
|
||||
# Create and retrieve session to verify full data layer
|
||||
session = await create_chat_session(health_check_user_id)
|
||||
session = await create_chat_session(health_check_user_id, dry_run=False)
|
||||
await get_chat_session(session.session_id, health_check_user_id)
|
||||
|
||||
return {
|
||||
@@ -1208,6 +1319,10 @@ ToolResponseUnion = (
|
||||
| DocPageResponse
|
||||
| MCPToolsDiscoveredResponse
|
||||
| MCPToolOutputResponse
|
||||
| MemoryStoreResponse
|
||||
| MemorySearchResponse
|
||||
| MemoryForgetCandidatesResponse
|
||||
| MemoryForgetConfirmResponse
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.chat import routes as chat_routes
|
||||
from backend.api.features.chat.routes import _strip_injected_context
|
||||
from backend.copilot.rate_limit import SubscriptionTier
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(chat_routes.router)
|
||||
@@ -131,16 +133,23 @@ def test_stream_chat_rejects_too_many_file_ids():
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing Redis/RabbitMQ."""
|
||||
validation and enrichment logic without needing RabbitMQ.
|
||||
|
||||
Returns:
|
||||
A namespace with ``save`` and ``enqueue`` mock objects so
|
||||
callers can make additional assertions about side-effects.
|
||||
"""
|
||||
import types
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
mock_save = mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
return_value=MagicMock(), # non-None = message was saved (not a duplicate)
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
@@ -148,7 +157,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mocker.patch(
|
||||
mock_enqueue = mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
@@ -156,9 +165,12 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
return types.SimpleNamespace(
|
||||
save=mock_save, enqueue=mock_enqueue, registry=mock_registry
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
|
||||
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
|
||||
_mock_stream_internals(mocker)
|
||||
# Patch workspace lookup as imported by the routes module
|
||||
@@ -184,10 +196,33 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ─── Duplicate message dedup ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_skips_enqueue_for_duplicate_message(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
):
|
||||
"""When append_and_save_message returns None (duplicate detected),
|
||||
enqueue_copilot_turn and stream_registry.create_session must NOT be called
|
||||
to avoid double-processing and to prevent overwriting the active stream's
|
||||
turn_id in Redis (which would cause reconnecting clients to miss the response)."""
|
||||
mocks = _mock_stream_internals(mocker)
|
||||
# Override save to return None — signalling a duplicate
|
||||
mocks.save.return_value = None
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mocks.enqueue.assert_not_called()
|
||||
mocks.registry.create_session.assert_not_called()
|
||||
|
||||
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockerFixture):
|
||||
"""Non-UUID strings in file_ids should be silently filtered out
|
||||
and NOT passed to the database query."""
|
||||
_mock_stream_internals(mocker)
|
||||
@@ -226,7 +261,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockerFixture):
|
||||
"""The batch query should scope to the user's workspace."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
@@ -255,7 +290,7 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
# ─── Rate limit → 429 ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerFixture):
|
||||
"""When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
@@ -276,7 +311,9 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFix
|
||||
assert "daily" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
def test_stream_chat_returns_429_on_weekly_rate_limit(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
):
|
||||
"""When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
@@ -299,7 +336,7 @@ def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFi
|
||||
assert "resets in" in detail
|
||||
|
||||
|
||||
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
|
||||
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture):
|
||||
"""The 429 response detail should include the human-readable reset time."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
@@ -331,14 +368,28 @@ def _mock_usage(
|
||||
*,
|
||||
daily_used: int = 500,
|
||||
weekly_used: int = 2000,
|
||||
daily_limit: int = 10000,
|
||||
weekly_limit: int = 50000,
|
||||
tier: "SubscriptionTier" = SubscriptionTier.FREE,
|
||||
) -> AsyncMock:
|
||||
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||
"""Mock get_usage_status and get_global_rate_limits for usage endpoint tests.
|
||||
|
||||
Mocks both ``get_global_rate_limits`` (returns the given limits + tier) and
|
||||
``get_usage_status`` so that tests exercise the endpoint without hitting
|
||||
LaunchDarkly or Prisma.
|
||||
"""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(daily_limit, weekly_limit, tier),
|
||||
)
|
||||
|
||||
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||
daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at),
|
||||
)
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
@@ -369,6 +420,7 @@ def test_usage_returns_daily_and_weekly(
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
|
||||
tier=SubscriptionTier.FREE,
|
||||
)
|
||||
|
||||
|
||||
@@ -376,11 +428,9 @@ def test_usage_uses_config_limits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||
mock_get = _mock_usage(mocker)
|
||||
"""The endpoint forwards resolved limits from get_global_rate_limits to get_usage_status."""
|
||||
mock_get = _mock_usage(mocker, daily_limit=99999, weekly_limit=77777)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
|
||||
|
||||
response = client.get("/usage")
|
||||
@@ -391,6 +441,7 @@ def test_usage_uses_config_limits(
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
rate_limit_reset_cost=500,
|
||||
tier=SubscriptionTier.FREE,
|
||||
)
|
||||
|
||||
|
||||
@@ -469,3 +520,296 @@ def test_suggested_prompts_empty_prompts(
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"themes": []}
|
||||
|
||||
|
||||
# ─── Create session: dry_run contract ─────────────────────────────────
|
||||
|
||||
|
||||
def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
|
||||
"""Mock create_chat_session to return a fake session."""
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
async def _fake_create(user_id: str, *, dry_run: bool):
|
||||
return ChatSession.new(user_id, dry_run=dry_run)
|
||||
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.create_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=_fake_create,
|
||||
)
|
||||
|
||||
|
||||
def test_create_session_dry_run_true(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
|
||||
_mock_create_chat_session(mocker)
|
||||
|
||||
response = client.post("/sessions", json={"dry_run": True})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["metadata"]["dry_run"] is True
|
||||
|
||||
|
||||
def test_create_session_dry_run_default_false(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Empty body defaults dry_run to False."""
|
||||
_mock_create_chat_session(mocker)
|
||||
|
||||
response = client.post("/sessions")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["metadata"]["dry_run"] is False
|
||||
|
||||
|
||||
def test_create_session_rejects_nested_metadata(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
|
||||
default to ``dry_run=False``. This guards against the common mistake of
|
||||
nesting dry_run inside metadata instead of providing it at the top level."""
|
||||
response = client.post(
|
||||
"/sessions",
|
||||
json={"metadata": {"dry_run": True}},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestStreamChatRequestModeValidation:
|
||||
"""Pydantic-level validation of the ``mode`` field on StreamChatRequest."""
|
||||
|
||||
def test_rejects_invalid_mode_value(self) -> None:
|
||||
"""Any string outside the Literal set must raise ValidationError."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
StreamChatRequest(message="hi", mode="turbo") # type: ignore[arg-type]
|
||||
|
||||
def test_accepts_fast_mode(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode="fast")
|
||||
assert req.mode == "fast"
|
||||
|
||||
def test_accepts_extended_thinking_mode(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode="extended_thinking")
|
||||
assert req.mode == "extended_thinking"
|
||||
|
||||
def test_accepts_none_mode(self) -> None:
|
||||
"""``mode=None`` is valid (server decides via feature flags)."""
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode=None)
|
||||
assert req.mode is None
|
||||
|
||||
def test_mode_defaults_to_none_when_omitted(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi")
|
||||
assert req.mode is None
|
||||
|
||||
|
||||
class TestStripInjectedContext:
|
||||
"""Unit tests for `_strip_injected_context` — the GET-side helper that
|
||||
hides the server-injected `<user_context>` block from API responses.
|
||||
|
||||
The strip is intentionally exact-match: it only removes the prefix the
|
||||
inject helper writes (`<user_context>...</user_context>\\n\\n` at the very
|
||||
start of the message). Any drift between writer and reader leaves the raw
|
||||
block visible in the chat history, which is the failure mode this suite
|
||||
documents.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _msg(role: str, content):
|
||||
return {"role": role, "content": content}
|
||||
|
||||
def test_strips_well_formed_prefix(self) -> None:
|
||||
|
||||
original = "<user_context>\nbiz ctx\n</user_context>\n\nhello world"
|
||||
result = _strip_injected_context(self._msg("user", original))
|
||||
assert result["content"] == "hello world"
|
||||
|
||||
def test_passes_through_message_without_prefix(self) -> None:
|
||||
|
||||
result = _strip_injected_context(self._msg("user", "just a question"))
|
||||
assert result["content"] == "just a question"
|
||||
|
||||
def test_only_strips_when_prefix_is_at_start(self) -> None:
|
||||
"""An embedded `<user_context>` block later in the message must NOT
|
||||
be stripped — only the leading prefix is server-injected."""
|
||||
|
||||
content = (
|
||||
"I copied this from somewhere: <user_context>\nfoo\n</user_context>\n\n"
|
||||
)
|
||||
result = _strip_injected_context(self._msg("user", content))
|
||||
assert result["content"] == content
|
||||
|
||||
def test_does_not_strip_with_only_single_newline_separator(self) -> None:
|
||||
"""The strip regex requires `\\n\\n` after the closing tag — a single
|
||||
newline indicates a different format and must not be touched."""
|
||||
|
||||
content = "<user_context>\nfoo\n</user_context>\nhello"
|
||||
result = _strip_injected_context(self._msg("user", content))
|
||||
assert result["content"] == content
|
||||
|
||||
def test_assistant_messages_pass_through(self) -> None:
|
||||
|
||||
original = "<user_context>\nfoo\n</user_context>\n\nhi"
|
||||
result = _strip_injected_context(self._msg("assistant", original))
|
||||
assert result["content"] == original
|
||||
|
||||
def test_non_string_content_passes_through(self) -> None:
|
||||
"""Multimodal / structured content (e.g. list of blocks) is not a
|
||||
string and must not be touched by the strip helper."""
|
||||
|
||||
blocks = [{"type": "text", "text": "hello"}]
|
||||
result = _strip_injected_context(self._msg("user", blocks))
|
||||
assert result["content"] is blocks
|
||||
|
||||
def test_strip_with_multiline_understanding(self) -> None:
|
||||
"""The understanding payload spans multiple lines (markdown headings,
|
||||
bullet points). `re.DOTALL` must allow the regex to span them."""
|
||||
|
||||
original = (
|
||||
"<user_context>\n"
|
||||
"# User Business Context\n\n"
|
||||
"## User\nName: Alice\n\n"
|
||||
"## Business\nCompany: Acme\n"
|
||||
"</user_context>\n\nactual question"
|
||||
)
|
||||
result = _strip_injected_context(self._msg("user", original))
|
||||
assert result["content"] == "actual question"
|
||||
|
||||
def test_strip_when_message_is_only_the_prefix(self) -> None:
|
||||
"""An empty user message gets injected with just the prefix; the
|
||||
strip should yield an empty string."""
|
||||
|
||||
original = "<user_context>\nctx\n</user_context>\n\n"
|
||||
result = _strip_injected_context(self._msg("user", original))
|
||||
assert result["content"] == ""
|
||||
|
||||
def test_does_not_mutate_original_dict(self) -> None:
|
||||
"""The helper must return a copy — the original dict stays intact."""
|
||||
original_content = "<user_context>\nctx\n</user_context>\n\nhello"
|
||||
msg = self._msg("user", original_content)
|
||||
result = _strip_injected_context(msg)
|
||||
assert result["content"] == "hello"
|
||||
assert msg["content"] == original_content
|
||||
assert result is not msg
|
||||
|
||||
def test_no_role_field_does_not_crash(self) -> None:
|
||||
|
||||
msg = {"content": "hello"}
|
||||
result = _strip_injected_context(msg)
|
||||
# Without a role, the helper short-circuits without touching content.
|
||||
assert result["content"] == "hello"
|
||||
|
||||
|
||||
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
|
||||
|
||||
|
||||
def test_disconnect_stream_returns_204_and_awaits_registry(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_session = MagicMock()
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_session,
|
||||
)
|
||||
mock_disconnect = mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
|
||||
new_callable=AsyncMock,
|
||||
return_value=2,
|
||||
)
|
||||
|
||||
response = client.delete("/sessions/sess-1/stream")
|
||||
|
||||
assert response.status_code == 204
|
||||
mock_disconnect.assert_awaited_once_with("sess-1")
|
||||
|
||||
|
||||
def test_disconnect_stream_returns_404_when_session_missing(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
mock_disconnect = mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.delete("/sessions/unknown-session/stream")
|
||||
|
||||
assert response.status_code == 404
|
||||
mock_disconnect.assert_not_awaited()
|
||||
|
||||
|
||||
# ─── GET /sessions/{session_id} — backward pagination ─────────────────────────
|
||||
|
||||
|
||||
def _make_paginated_messages(
|
||||
mocker: pytest_mock.MockerFixture, *, has_more: bool = False
|
||||
):
|
||||
"""Return a mock PaginatedMessages and configure the DB patch."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.copilot.db import PaginatedMessages
|
||||
from backend.copilot.model import ChatMessage, ChatSessionInfo, ChatSessionMetadata
|
||||
|
||||
now = datetime.now(UTC)
|
||||
session_info = ChatSessionInfo(
|
||||
session_id="sess-1",
|
||||
user_id=TEST_USER_ID,
|
||||
usage=[],
|
||||
started_at=now,
|
||||
updated_at=now,
|
||||
metadata=ChatSessionMetadata(),
|
||||
)
|
||||
page = PaginatedMessages(
|
||||
messages=[ChatMessage(role="user", content="hello", sequence=0)],
|
||||
has_more=has_more,
|
||||
oldest_sequence=0,
|
||||
session=session_info,
|
||||
)
|
||||
mock_paginate = mocker.patch(
|
||||
"backend.api.features.chat.routes.get_chat_messages_paginated",
|
||||
new_callable=AsyncMock,
|
||||
return_value=page,
|
||||
)
|
||||
return page, mock_paginate
|
||||
|
||||
|
||||
def test_get_session_returns_backward_paginated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""All sessions use backward (newest-first) pagination."""
|
||||
_make_paginated_messages(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(None, None),
|
||||
)
|
||||
|
||||
response = client.get("/sessions/sess-1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["oldest_sequence"] == 0
|
||||
assert "forward_paginated" not in data
|
||||
assert "newest_sequence" not in data
|
||||
|
||||
@@ -40,11 +40,15 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
create_mcp_oauth_handler,
|
||||
)
|
||||
from backend.integrations.managed_credentials import ensure_managed_credentials
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -110,6 +114,7 @@ class CredentialsMetaResponse(BaseModel):
|
||||
default=None,
|
||||
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
||||
)
|
||||
is_managed: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -148,6 +153,7 @@ def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
is_managed=cred.is_managed,
|
||||
)
|
||||
|
||||
|
||||
@@ -224,6 +230,9 @@ async def callback(
|
||||
async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
# Fire-and-forget: provision missing managed credentials in the background.
|
||||
# The credential appears on the next page load; listing is never blocked.
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
@@ -238,6 +247,7 @@ async def list_credentials_by_provider(
|
||||
],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
@@ -332,6 +342,11 @@ async def delete_credentials(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
@@ -342,6 +357,11 @@ async def delete_credentials(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials not found",
|
||||
)
|
||||
if creds.is_managed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="AutoGPT-managed credentials cannot be deleted",
|
||||
)
|
||||
|
||||
try:
|
||||
await remove_all_webhooks_for_credentials(user_id, creds, force)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -276,3 +277,294 @@ class TestCreateCredentialNoSecretInResponse:
|
||||
|
||||
assert resp.status_code == 403
|
||||
mock_mgr.create.assert_not_called()
|
||||
|
||||
|
||||
class TestManagedCredentials:
|
||||
"""AutoGPT-managed credentials cannot be deleted by users."""
|
||||
|
||||
def test_delete_is_managed_returns_403(self):
|
||||
cred = APIKeyCredentials(
|
||||
id="managed-cred-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-managed-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
|
||||
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
|
||||
|
||||
assert resp.status_code == 403
|
||||
assert "AutoGPT-managed" in resp.json()["detail"]
|
||||
|
||||
def test_list_credentials_includes_is_managed_field(self):
|
||||
managed = APIKeyCredentials(
|
||||
id="managed-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed)",
|
||||
api_key=SecretStr("sk-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
regular = APIKeyCredentials(
|
||||
id="regular-1",
|
||||
provider="openai",
|
||||
title="My Key",
|
||||
api_key=SecretStr("sk-key"),
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
managed_cred = next(c for c in data if c["id"] == "managed-1")
|
||||
regular_cred = next(c for c in data if c["id"] == "regular-1")
|
||||
assert managed_cred["is_managed"] is True
|
||||
assert regular_cred["is_managed"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Managed credential provisioning infrastructure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_managed_cred(
|
||||
provider: str = "agent_mail", pod_id: str = "pod-abc"
|
||||
) -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="managed-auto",
|
||||
provider=provider,
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-pod-key"),
|
||||
is_managed=True,
|
||||
metadata={"pod_id": pod_id},
|
||||
)
|
||||
|
||||
|
||||
def _make_store_mock(**kwargs) -> MagicMock:
|
||||
"""Create a store mock with a working async ``locks()`` context manager."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_locked(key):
|
||||
yield
|
||||
|
||||
locks_obj = MagicMock()
|
||||
locks_obj.locked = _noop_locked
|
||||
|
||||
store = MagicMock(**kwargs)
|
||||
store.locks = AsyncMock(return_value=locks_obj)
|
||||
return store
|
||||
|
||||
|
||||
class TestEnsureManagedCredentials:
|
||||
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provisions_when_missing(self):
|
||||
"""Provider.provision() is called when no managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(return_value=cred)
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
store.add_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_awaited_once_with("user-1")
|
||||
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_already_exists(self):
|
||||
"""Provider.provision() is NOT called when managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=True)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_unavailable(self):
|
||||
"""Provider.provision() is NOT called when provider is not available."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=False)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
store.has_managed_credential.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provision_failure_does_not_propagate(self):
|
||||
"""A failed provision is logged but does not raise."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
# No exception raised — provisioning failure is swallowed.
|
||||
|
||||
|
||||
class TestCleanupManagedCredentials:
|
||||
"""Unit tests for cleanup_managed_credentials."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_deprovision_for_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_non_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
regular = _make_api_key_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "openai"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[regular])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["openai"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprovision_failure_does_not_propagate(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
# No exception raised — cleanup failure is swallowed.
|
||||
|
||||
@@ -12,6 +12,7 @@ import prisma.models
|
||||
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.data.graph as graph_db
|
||||
from backend.api.features.library.db import _fetch_schedule_info
|
||||
from backend.data.graph import GraphModel, GraphSettings
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -117,4 +118,5 @@ async def add_graph_to_library(
|
||||
f"for store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=graph_model.id)
|
||||
return library_model.LibraryAgent.from_db(added_agent, schedule_info=schedule_info)
|
||||
|
||||
@@ -21,13 +21,17 @@ async def test_add_graph_to_library_create_new_agent() -> None:
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library._fetch_schedule_info",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(created_agent)
|
||||
mock_from_db.assert_called_once_with(created_agent, schedule_info={})
|
||||
# Verify create was called with correct data
|
||||
create_call = mock_prisma.return_value.create.call_args
|
||||
create_data = create_call.kwargs["data"]
|
||||
@@ -54,6 +58,10 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library._fetch_schedule_info",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(
|
||||
side_effect=prisma.errors.UniqueViolationError(
|
||||
@@ -65,7 +73,7 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(updated_agent)
|
||||
mock_from_db.assert_called_once_with(updated_agent, schedule_info={})
|
||||
# Verify update was called with correct where and data
|
||||
update_call = mock_prisma.return_value.update.call_args
|
||||
assert update_call.kwargs["where"] == {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
import fastapi
|
||||
@@ -43,6 +44,65 @@ config = Config()
|
||||
integration_creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]:
|
||||
"""Fetch execution counts per graph in a single batched query."""
|
||||
if not graph_ids:
|
||||
return {}
|
||||
rows = await prisma.models.AgentGraphExecution.prisma().group_by(
|
||||
by=["agentGraphId"],
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": {"in": graph_ids},
|
||||
"isDeleted": False,
|
||||
},
|
||||
count=True,
|
||||
)
|
||||
return {
|
||||
row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
|
||||
for row in rows
|
||||
}
|
||||
|
||||
|
||||
async def _fetch_schedule_info(
|
||||
user_id: str, graph_id: Optional[str] = None
|
||||
) -> dict[str, str]:
|
||||
"""Fetch a map of graph_id → earliest next_run_time ISO string.
|
||||
|
||||
When `graph_id` is provided, the scheduler query is narrowed to that graph,
|
||||
which is cheaper for single-agent lookups (detail page, post-update, etc.).
|
||||
"""
|
||||
try:
|
||||
scheduler_client = get_scheduler_client()
|
||||
schedules = await scheduler_client.get_execution_schedules(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
earliest: dict[str, tuple[datetime, str]] = {}
|
||||
for s in schedules:
|
||||
parsed = _parse_iso_datetime(s.next_run_time)
|
||||
if parsed is None:
|
||||
continue
|
||||
current = earliest.get(s.graph_id)
|
||||
if current is None or parsed < current[0]:
|
||||
earliest[s.graph_id] = (parsed, s.next_run_time)
|
||||
return {graph_id: iso for graph_id, (_, iso) in earliest.items()}
|
||||
except Exception:
|
||||
logger.warning("Failed to fetch schedules for library agents", exc_info=True)
|
||||
return {}
|
||||
|
||||
|
||||
def _parse_iso_datetime(value: str) -> Optional[datetime]:
|
||||
"""Parse an ISO 8601 datetime, tolerating `Z` and naive forms (assumed UTC)."""
|
||||
try:
|
||||
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
logger.warning("Failed to parse schedule next_run_time: %s", value)
|
||||
return None
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed
|
||||
|
||||
|
||||
async def list_library_agents(
|
||||
user_id: str,
|
||||
search_term: Optional[str] = None,
|
||||
@@ -137,12 +197,22 @@ async def list_library_agents(
|
||||
|
||||
logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")
|
||||
|
||||
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
|
||||
execution_counts, schedule_info = await asyncio.gather(
|
||||
_fetch_execution_counts(user_id, graph_ids),
|
||||
_fetch_schedule_info(user_id),
|
||||
)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent,
|
||||
execution_count_override=execution_counts.get(agent.agentGraphId),
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
@@ -214,12 +284,22 @@ async def list_favorite_library_agents(
|
||||
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
|
||||
)
|
||||
|
||||
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
|
||||
execution_counts, schedule_info = await asyncio.gather(
|
||||
_fetch_execution_counts(user_id, graph_ids),
|
||||
_fetch_schedule_info(user_id),
|
||||
)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent,
|
||||
execution_count_override=execution_counts.get(agent.agentGraphId),
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
@@ -285,6 +365,12 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
where={"userId": store_listing.owningUserId}
|
||||
)
|
||||
|
||||
schedule_info = (
|
||||
await _fetch_schedule_info(user_id, graph_id=library_agent.AgentGraph.id)
|
||||
if library_agent.AgentGraph
|
||||
else {}
|
||||
)
|
||||
|
||||
return library_model.LibraryAgent.from_db(
|
||||
library_agent,
|
||||
sub_graphs=(
|
||||
@@ -294,6 +380,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
),
|
||||
store_listing=store_listing,
|
||||
profile=profile,
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -329,7 +416,10 @@ async def get_library_agent_by_store_version_id(
|
||||
},
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(agent) if agent else None
|
||||
if not agent:
|
||||
return None
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
|
||||
return library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
|
||||
|
||||
async def get_library_agent_by_graph_id(
|
||||
@@ -358,7 +448,10 @@ async def get_library_agent_by_graph_id(
|
||||
assert agent.AgentGraph # make type checker happy
|
||||
# Include sub-graphs so we can make a full credentials input schema
|
||||
sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph)
|
||||
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
|
||||
return library_model.LibraryAgent.from_db(
|
||||
agent, sub_graphs=sub_graphs, schedule_info=schedule_info
|
||||
)
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
@@ -481,6 +574,11 @@ async def create_library_agent(
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
},
|
||||
},
|
||||
include=library_agent_include(
|
||||
@@ -495,7 +593,11 @@ async def create_library_agent(
|
||||
for agent, graph in zip(library_agents, graph_entries):
|
||||
asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id))
|
||||
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
|
||||
schedule_info = await _fetch_schedule_info(user_id)
|
||||
return [
|
||||
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
for agent in library_agents
|
||||
]
|
||||
|
||||
|
||||
async def update_agent_version_in_library(
|
||||
@@ -557,7 +659,8 @@ async def update_agent_version_in_library(
|
||||
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
|
||||
)
|
||||
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent_graph_id)
|
||||
return library_model.LibraryAgent.from_db(lib, schedule_info=schedule_info)
|
||||
|
||||
|
||||
async def create_graph_in_library(
|
||||
@@ -1462,7 +1565,11 @@ async def bulk_move_agents_to_folder(
|
||||
),
|
||||
)
|
||||
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
|
||||
schedule_info = await _fetch_schedule_info(user_id)
|
||||
return [
|
||||
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
for agent in agents
|
||||
]
|
||||
|
||||
|
||||
def collect_tree_ids(
|
||||
|
||||
@@ -65,6 +65,11 @@ async def test_get_library_agents(mocker):
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={}),
|
||||
)
|
||||
|
||||
# Call function
|
||||
result = await db.list_library_agents("test-user")
|
||||
|
||||
@@ -353,3 +358,136 @@ async def test_create_library_agent_uses_upsert():
|
||||
# Verify update branch restores soft-deleted/archived agents
|
||||
assert data["update"]["isDeleted"] is False
|
||||
assert data["update"]["isArchived"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_favorite_library_agents(mocker):
|
||||
mock_library_agents = [
|
||||
prisma.models.LibraryAgent(
|
||||
id="fav1",
|
||||
userId="test-user",
|
||||
agentGraphId="agent-fav",
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=True,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id="agent-fav",
|
||||
version=1,
|
||||
name="Favorite Agent",
|
||||
description="My Favorite",
|
||||
userId="other-user",
|
||||
isActive=True,
|
||||
createdAt=datetime.now(),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=mock_library_agents
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={"agent-fav": 7}),
|
||||
)
|
||||
|
||||
result = await db.list_favorite_library_agents("test-user")
|
||||
|
||||
assert len(result.agents) == 1
|
||||
assert result.agents[0].id == "fav1"
|
||||
assert result.agents[0].name == "Favorite Agent"
|
||||
assert result.agents[0].graph_id == "agent-fav"
|
||||
assert result.pagination.total_items == 1
|
||||
assert result.pagination.total_pages == 1
|
||||
assert result.pagination.current_page == 1
|
||||
assert result.pagination.page_size == 50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_library_agents_skips_failed_agent(mocker):
|
||||
"""Agents that fail parsing should be skipped — covers the except branch."""
|
||||
mock_library_agents = [
|
||||
prisma.models.LibraryAgent(
|
||||
id="ua-bad",
|
||||
userId="test-user",
|
||||
agentGraphId="agent-bad",
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id="agent-bad",
|
||||
version=1,
|
||||
name="Bad Agent",
|
||||
description="",
|
||||
userId="other-user",
|
||||
isActive=True,
|
||||
createdAt=datetime.now(),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=mock_library_agents
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={}),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db",
|
||||
side_effect=Exception("parse error"),
|
||||
)
|
||||
|
||||
result = await db.list_library_agents("test-user")
|
||||
|
||||
assert len(result.agents) == 0
|
||||
assert result.pagination.total_items == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_execution_counts_empty_graph_ids():
|
||||
result = await db._fetch_execution_counts("user-1", [])
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_execution_counts_uses_group_by(mocker):
|
||||
mock_prisma = mocker.patch("prisma.models.AgentGraphExecution.prisma")
|
||||
mock_prisma.return_value.group_by = mocker.AsyncMock(
|
||||
return_value=[
|
||||
{"agentGraphId": "graph-1", "_count": {"_all": 5}},
|
||||
{"agentGraphId": "graph-2", "_count": {"_all": 2}},
|
||||
]
|
||||
)
|
||||
|
||||
result = await db._fetch_execution_counts(
|
||||
"user-1", ["graph-1", "graph-2", "graph-3"]
|
||||
)
|
||||
|
||||
assert result == {"graph-1": 5, "graph-2": 2}
|
||||
mock_prisma.return_value.group_by.assert_called_once_with(
|
||||
by=["agentGraphId"],
|
||||
where={
|
||||
"userId": "user-1",
|
||||
"agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]},
|
||||
"isDeleted": False,
|
||||
},
|
||||
count=True,
|
||||
)
|
||||
|
||||
@@ -214,6 +214,14 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
folder_name: str | None = None # Denormalized for display
|
||||
|
||||
recommended_schedule_cron: str | None = None
|
||||
is_scheduled: bool = pydantic.Field(
|
||||
default=False,
|
||||
description="Whether this agent has active execution schedules",
|
||||
)
|
||||
next_scheduled_run: str | None = pydantic.Field(
|
||||
default=None,
|
||||
description="ISO 8601 timestamp of the next scheduled run, if any",
|
||||
)
|
||||
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
||||
marketplace_listing: Optional["MarketplaceListing"] = None
|
||||
|
||||
@@ -223,6 +231,8 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
|
||||
store_listing: Optional[prisma.models.StoreListing] = None,
|
||||
profile: Optional[prisma.models.Profile] = None,
|
||||
execution_count_override: Optional[int] = None,
|
||||
schedule_info: Optional[dict[str, str]] = None,
|
||||
) -> "LibraryAgent":
|
||||
"""
|
||||
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
|
||||
@@ -258,10 +268,14 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
status = status_result.status
|
||||
new_output = status_result.new_output
|
||||
|
||||
execution_count = len(executions)
|
||||
execution_count = (
|
||||
execution_count_override
|
||||
if execution_count_override is not None
|
||||
else len(executions)
|
||||
)
|
||||
success_rate: float | None = None
|
||||
avg_correctness_score: float | None = None
|
||||
if execution_count > 0:
|
||||
if executions and execution_count > 0:
|
||||
success_count = sum(
|
||||
1
|
||||
for e in executions
|
||||
@@ -354,6 +368,10 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
folder_id=agent.folderId,
|
||||
folder_name=agent.Folder.name if agent.Folder else None,
|
||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||
is_scheduled=bool(schedule_info and agent.agentGraphId in schedule_info),
|
||||
next_scheduled_run=(
|
||||
schedule_info.get(agent.agentGraphId) if schedule_info else None
|
||||
),
|
||||
settings=_parse_settings(agent.settings),
|
||||
marketplace_listing=marketplace_listing_data,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,66 @@
|
||||
import datetime
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
from . import model as library_model
|
||||
|
||||
|
||||
def _make_library_agent(
|
||||
*,
|
||||
graph_id: str = "g1",
|
||||
executions: list | None = None,
|
||||
) -> prisma.models.LibraryAgent:
|
||||
return prisma.models.LibraryAgent(
|
||||
id="la1",
|
||||
userId="u1",
|
||||
agentGraphId=graph_id,
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.datetime.now(),
|
||||
updatedAt=datetime.datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id=graph_id,
|
||||
version=1,
|
||||
name="Agent",
|
||||
description="Desc",
|
||||
userId="u1",
|
||||
isActive=True,
|
||||
createdAt=datetime.datetime.now(),
|
||||
Executions=executions,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_from_db_execution_count_override_covers_success_rate():
|
||||
"""Covers execution_count_override is not None branch and executions/count > 0 block."""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
exec1 = prisma.models.AgentGraphExecution(
|
||||
id="exec-1",
|
||||
agentGraphId="g1",
|
||||
agentGraphVersion=1,
|
||||
userId="u1",
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
agent = _make_library_agent(executions=[exec1])
|
||||
|
||||
result = library_model.LibraryAgent.from_db(agent, execution_count_override=1)
|
||||
|
||||
assert result.execution_count == 1
|
||||
assert result.success_rate is not None
|
||||
assert result.success_rate == 100.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_preset_from_db(test_user_id: str):
|
||||
# Create mock DB agent
|
||||
|
||||
@@ -12,6 +12,7 @@ Tests cover:
|
||||
5. Complete OAuth flow end-to-end
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
@@ -58,14 +59,27 @@ async def test_user(server, test_user_id: str):
|
||||
|
||||
yield test_user_id
|
||||
|
||||
# Cleanup - delete in correct order due to foreign key constraints
|
||||
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
)
|
||||
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
|
||||
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
||||
# Cleanup - delete in correct order due to foreign key constraints.
|
||||
# Wrap in try/except because the event loop or Prisma engine may already
|
||||
# be closed during session teardown on Python 3.12+.
|
||||
try:
|
||||
await asyncio.gather(
|
||||
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
|
||||
PrismaOAuthRefreshToken.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
)
|
||||
await asyncio.gather(
|
||||
PrismaOAuthApplication.prisma().delete_many(
|
||||
where={"ownerId": test_user_id}
|
||||
),
|
||||
PrismaUser.prisma().delete(where={"id": test_user_id}),
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
|
||||
from backend.api.features.v1 import v1_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(v1_router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_onboarding_profile_success(mocker):
|
||||
mock_extract = mocker.patch(
|
||||
"backend.api.features.v1.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
mock_upsert = mocker.patch(
|
||||
"backend.api.features.v1.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
from backend.data.understanding import BusinessUnderstandingInput
|
||||
|
||||
mock_extract.return_value = BusinessUnderstandingInput.model_construct(
|
||||
user_name="John",
|
||||
user_role="Founder/CEO",
|
||||
pain_points=["Finding leads"],
|
||||
suggested_prompts={"Learn": ["How do I automate lead gen?"]},
|
||||
)
|
||||
mock_upsert.return_value = AsyncMock()
|
||||
|
||||
response = client.post(
|
||||
"/onboarding/profile",
|
||||
json={
|
||||
"user_name": "John",
|
||||
"user_role": "Founder/CEO",
|
||||
"pain_points": ["Finding leads", "Email & outreach"],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_extract.assert_awaited_once()
|
||||
mock_upsert.assert_awaited_once()
|
||||
|
||||
|
||||
def test_onboarding_profile_missing_fields():
|
||||
response = client.post(
|
||||
"/onboarding/profile",
|
||||
json={"user_name": "John"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
@@ -189,6 +189,7 @@ async def test_create_store_submission(mocker):
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="Europe/Delft",
|
||||
subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
|
||||
)
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
id="agent-id",
|
||||
|
||||
@@ -0,0 +1,805 @@
|
||||
"""Tests for subscription tier API endpoints."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
import stripe
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from prisma.enums import SubscriptionTier
|
||||
|
||||
from .v1 import _validate_checkout_redirect_url, v1_router
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
TEST_FRONTEND_ORIGIN = "https://app.example.com"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client() -> fastapi.testclient.TestClient:
|
||||
"""Fresh FastAPI app + client per test with auth override applied.
|
||||
|
||||
Using a fixture avoids the leaky global-app + try/finally teardown pattern:
|
||||
if a test body raises before teardown_auth runs, dependency overrides were
|
||||
previously leaking into subsequent tests.
|
||||
"""
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(v1_router)
|
||||
|
||||
def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
|
||||
return {"sub": TEST_USER_ID, "role": "user", "email": "test@example.com"}
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = override_get_jwt_payload
|
||||
try:
|
||||
yield fastapi.testclient.TestClient(app)
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _configure_frontend_origin(mocker: pytest_mock.MockFixture) -> None:
|
||||
"""Pin the configured frontend origin used by the open-redirect guard."""
|
||||
from backend.api.features import v1 as v1_mod
|
||||
|
||||
mocker.patch.object(
|
||||
v1_mod.settings.config, "frontend_base_url", TEST_FRONTEND_ORIGIN
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url,expected",
|
||||
[
|
||||
# Valid URLs matching the configured frontend origin
|
||||
(f"{TEST_FRONTEND_ORIGIN}/success", True),
|
||||
(f"{TEST_FRONTEND_ORIGIN}/cancel?ref=abc", True),
|
||||
# Wrong origin
|
||||
("https://evil.example.org/phish", False),
|
||||
("https://evil.example.org", False),
|
||||
# @ in URL (user:pass@host attack)
|
||||
(f"https://attacker.example.com@{TEST_FRONTEND_ORIGIN}/ok", False),
|
||||
# Backslash normalisation attack
|
||||
(f"https:{TEST_FRONTEND_ORIGIN}\\@attacker.example.com/ok", False),
|
||||
# javascript: scheme
|
||||
("javascript:alert(1)", False),
|
||||
# Empty string
|
||||
("", False),
|
||||
# Control character (U+0000) in URL
|
||||
(f"{TEST_FRONTEND_ORIGIN}/ok\x00evil", False),
|
||||
# Non-http scheme
|
||||
(f"ftp://{TEST_FRONTEND_ORIGIN}/ok", False),
|
||||
],
|
||||
)
|
||||
def test_validate_checkout_redirect_url(
|
||||
url: str,
|
||||
expected: bool,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""_validate_checkout_redirect_url rejects adversarial inputs."""
|
||||
from backend.api.features import v1 as v1_mod
|
||||
|
||||
mocker.patch.object(
|
||||
v1_mod.settings.config, "frontend_base_url", TEST_FRONTEND_ORIGIN
|
||||
)
|
||||
assert _validate_checkout_redirect_url(url) is expected
|
||||
|
||||
|
||||
def test_get_subscription_status_pro(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""GET /credits/subscription returns PRO tier with Stripe price for a PRO user."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
async def mock_price_id(tier: SubscriptionTier) -> str | None:
|
||||
return "price_pro" if tier == SubscriptionTier.PRO else None
|
||||
|
||||
async def mock_stripe_price_amount(price_id: str) -> int:
|
||||
return 1999 if price_id == "price_pro" else 0
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_subscription_price_id",
|
||||
side_effect=mock_price_id,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1._get_stripe_price_amount",
|
||||
side_effect=mock_stripe_price_amount,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_proration_credit_cents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=500,
|
||||
)
|
||||
|
||||
response = client.get("/credits/subscription")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == "PRO"
|
||||
assert data["monthly_cost"] == 1999
|
||||
assert data["tier_costs"]["PRO"] == 1999
|
||||
assert data["tier_costs"]["BUSINESS"] == 0
|
||||
assert data["tier_costs"]["FREE"] == 0
|
||||
assert data["proration_credit_cents"] == 500
|
||||
|
||||
|
||||
def test_get_subscription_status_defaults_to_free(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""GET /credits/subscription when subscription_tier is None defaults to FREE."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = None
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_subscription_price_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_proration_credit_cents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=0,
|
||||
)
|
||||
|
||||
response = client.get("/credits/subscription")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == SubscriptionTier.FREE.value
|
||||
assert data["monthly_cost"] == 0
|
||||
assert data["tier_costs"] == {
|
||||
"FREE": 0,
|
||||
"PRO": 0,
|
||||
"BUSINESS": 0,
|
||||
"ENTERPRISE": 0,
|
||||
}
|
||||
assert data["proration_credit_cents"] == 0
|
||||
|
||||
|
||||
def test_get_subscription_status_stripe_error_falls_back_to_zero(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""GET /credits/subscription returns cost=0 when Stripe price fetch fails (returns None).
|
||||
|
||||
_get_stripe_price_amount returns None on StripeError so the error state is
|
||||
not cached. The endpoint must treat None as 0 — not raise or return invalid data.
|
||||
"""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
async def mock_price_id(tier: SubscriptionTier) -> str | None:
|
||||
return "price_pro" if tier == SubscriptionTier.PRO else None
|
||||
|
||||
async def mock_stripe_price_amount_none(price_id: str) -> None:
|
||||
return None
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_subscription_price_id",
|
||||
side_effect=mock_price_id,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1._get_stripe_price_amount",
|
||||
side_effect=mock_stripe_price_amount_none,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_proration_credit_cents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=0,
|
||||
)
|
||||
|
||||
response = client.get("/credits/subscription")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == "PRO"
|
||||
# When Stripe returns None, cost falls back to 0
|
||||
assert data["monthly_cost"] == 0
|
||||
assert data["tier_costs"]["PRO"] == 0
|
||||
|
||||
|
||||
def test_update_subscription_tier_free_no_payment(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""POST /credits/subscription to FREE tier when payment disabled skips Stripe."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
async def mock_feature_disabled(*args, **kwargs):
|
||||
return False
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_disabled,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.set_subscription_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post("/credits/subscription", json={"tier": "FREE"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == ""
|
||||
|
||||
|
||||
def test_update_subscription_tier_paid_beta_user(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""POST /credits/subscription for paid tier when payment disabled returns 422."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.FREE
|
||||
|
||||
async def mock_feature_disabled(*args, **kwargs):
|
||||
return False
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_disabled,
|
||||
)
|
||||
|
||||
response = client.post("/credits/subscription", json={"tier": "PRO"})
|
||||
|
||||
assert response.status_code == 422
|
||||
assert "not available" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_update_subscription_tier_paid_requires_urls(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""POST /credits/subscription for paid tier without success/cancel URLs returns 422."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.FREE
|
||||
|
||||
async def mock_feature_enabled(*args, **kwargs):
|
||||
return True
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
|
||||
response = client.post("/credits/subscription", json={"tier": "PRO"})
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_update_subscription_tier_creates_checkout(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""POST /credits/subscription creates Stripe Checkout Session for paid upgrade."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.FREE
|
||||
|
||||
async def mock_feature_enabled(*args, **kwargs):
|
||||
return True
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.create_subscription_checkout",
|
||||
new_callable=AsyncMock,
|
||||
return_value="https://checkout.stripe.com/pay/cs_test_abc",
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/subscription",
|
||||
json={
|
||||
"tier": "PRO",
|
||||
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
|
||||
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == "https://checkout.stripe.com/pay/cs_test_abc"
|
||||
|
||||
|
||||
def test_update_subscription_tier_rejects_open_redirect(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""POST /credits/subscription rejects success/cancel URLs outside the frontend origin."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.FREE
|
||||
|
||||
async def mock_feature_enabled(*args, **kwargs):
|
||||
return True
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
checkout_mock = mocker.patch(
|
||||
"backend.api.features.v1.create_subscription_checkout",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/subscription",
|
||||
json={
|
||||
"tier": "PRO",
|
||||
"success_url": "https://evil.example.org/phish",
|
||||
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
checkout_mock.assert_not_awaited()
|
||||
|
||||
|
||||
def test_update_subscription_tier_enterprise_blocked(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""ENTERPRISE users cannot self-service change tiers — must get 403."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.ENTERPRISE
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
set_tier_mock = mocker.patch(
|
||||
"backend.api.features.v1.set_subscription_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/subscription",
|
||||
json={
|
||||
"tier": "PRO",
|
||||
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
|
||||
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
set_tier_mock.assert_not_awaited()
|
||||
|
||||
|
||||
def test_update_subscription_tier_same_tier_is_noop(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""POST /credits/subscription for the user's current paid tier returns 200 with empty URL.
|
||||
|
||||
Without this guard a duplicate POST (double-click, browser retry, stale page) would
|
||||
create a second Stripe Checkout Session for the same price, potentially billing the
|
||||
user twice until the webhook reconciliation fires.
|
||||
"""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
async def mock_feature_enabled(*args, **kwargs):
|
||||
return True
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
checkout_mock = mocker.patch(
|
||||
"backend.api.features.v1.create_subscription_checkout",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/subscription",
|
||||
json={
|
||||
"tier": "PRO",
|
||||
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
|
||||
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == ""
|
||||
checkout_mock.assert_not_awaited()
|
||||
|
||||
|
||||
def test_update_subscription_tier_free_with_payment_schedules_cancel_and_does_not_update_db(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Downgrading to FREE schedules Stripe cancellation at period end.
|
||||
|
||||
The DB tier must NOT be updated immediately — the customer.subscription.deleted
|
||||
webhook fires at period end and downgrades to FREE then.
|
||||
"""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
async def mock_feature_enabled(*args, **kwargs):
|
||||
return True
|
||||
|
||||
mock_cancel = mocker.patch(
|
||||
"backend.api.features.v1.cancel_stripe_subscription",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
mock_set_tier = mocker.patch(
|
||||
"backend.api.features.v1.set_subscription_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
|
||||
response = client.post("/credits/subscription", json={"tier": "FREE"})
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_cancel.assert_awaited_once()
|
||||
mock_set_tier.assert_not_awaited()
|
||||
|
||||
|
||||
def test_update_subscription_tier_free_cancel_failure_returns_502(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Downgrading to FREE returns 502 with a generic error (no Stripe detail leakage)."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
async def mock_feature_enabled(*args, **kwargs):
|
||||
return True
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.cancel_stripe_subscription",
|
||||
side_effect=stripe.StripeError(
|
||||
"You did not provide an API key — internal detail that must not leak"
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
|
||||
response = client.post("/credits/subscription", json={"tier": "FREE"})
|
||||
|
||||
assert response.status_code == 502
|
||||
detail = response.json()["detail"]
|
||||
# The raw Stripe error message must not appear in the client-facing detail.
|
||||
assert "API key" not in detail
|
||||
assert "contact support" in detail.lower()
|
||||
|
||||
|
||||
def test_stripe_webhook_unconfigured_secret_returns_503(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Stripe webhook endpoint returns 503 when STRIPE_WEBHOOK_SECRET is not set.
|
||||
|
||||
An empty webhook secret allows HMAC forgery: an attacker can compute a valid
|
||||
HMAC signature over the same empty key. The handler must reject all requests
|
||||
when the secret is unconfigured rather than proceeding with signature verification.
|
||||
"""
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings.secrets.stripe_webhook_secret",
|
||||
new="",
|
||||
)
|
||||
response = client.post(
|
||||
"/credits/stripe_webhook",
|
||||
content=b"{}",
|
||||
headers={"stripe-signature": "t=1,v1=fake"},
|
||||
)
|
||||
assert response.status_code == 503
|
||||
|
||||
|
||||
def test_stripe_webhook_dispatches_subscription_events(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""POST /credits/stripe_webhook routes customer.subscription.created to sync handler."""
|
||||
stripe_sub_obj = {
|
||||
"id": "sub_test",
|
||||
"customer": "cus_test",
|
||||
"status": "active",
|
||||
"items": {"data": [{"price": {"id": "price_pro"}}]},
|
||||
}
|
||||
event = {
|
||||
"type": "customer.subscription.created",
|
||||
"data": {"object": stripe_sub_obj},
|
||||
}
|
||||
|
||||
# Ensure the webhook secret guard passes (non-empty secret required).
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings.secrets.stripe_webhook_secret",
|
||||
new="whsec_test",
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.stripe.Webhook.construct_event",
|
||||
return_value=event,
|
||||
)
|
||||
sync_mock = mocker.patch(
|
||||
"backend.api.features.v1.sync_subscription_from_stripe",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/stripe_webhook",
|
||||
content=b"{}",
|
||||
headers={"stripe-signature": "t=1,v1=abc"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
sync_mock.assert_awaited_once_with(stripe_sub_obj)
|
||||
|
||||
|
||||
def test_stripe_webhook_dispatches_invoice_payment_failed(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""POST /credits/stripe_webhook routes invoice.payment_failed to the failure handler."""
|
||||
invoice_obj = {
|
||||
"customer": "cus_test",
|
||||
"subscription": "sub_test",
|
||||
"amount_due": 1999,
|
||||
}
|
||||
event = {
|
||||
"type": "invoice.payment_failed",
|
||||
"data": {"object": invoice_obj},
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings.secrets.stripe_webhook_secret",
|
||||
new="whsec_test",
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.stripe.Webhook.construct_event",
|
||||
return_value=event,
|
||||
)
|
||||
failure_mock = mocker.patch(
|
||||
"backend.api.features.v1.handle_subscription_payment_failure",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/stripe_webhook",
|
||||
content=b"{}",
|
||||
headers={"stripe-signature": "t=1,v1=abc"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
failure_mock.assert_awaited_once_with(invoice_obj)
|
||||
|
||||
|
||||
def test_update_subscription_tier_paid_to_paid_modifies_subscription(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""POST /credits/subscription modifies existing subscription for paid→paid changes."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
)
|
||||
modify_mock = mocker.patch(
|
||||
"backend.api.features.v1.modify_stripe_subscription_for_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
)
|
||||
checkout_mock = mocker.patch(
|
||||
"backend.api.features.v1.create_subscription_checkout",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/subscription",
|
||||
json={
|
||||
"tier": "BUSINESS",
|
||||
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
|
||||
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == ""
|
||||
modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS)
|
||||
checkout_mock.assert_not_awaited()
|
||||
|
||||
|
||||
def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Admin-granted paid tier users are NOT sent to Stripe checkout for paid→paid changes.
|
||||
|
||||
When modify_stripe_subscription_for_tier returns False (no Stripe subscription
|
||||
found — admin-granted tier), the endpoint must update the DB tier directly and
|
||||
return 200 with url="", rather than falling through to Checkout Session creation.
|
||||
"""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
)
|
||||
# Return False = no Stripe subscription (admin-granted tier)
|
||||
modify_mock = mocker.patch(
|
||||
"backend.api.features.v1.modify_stripe_subscription_for_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
)
|
||||
set_tier_mock = mocker.patch(
|
||||
"backend.api.features.v1.set_subscription_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
checkout_mock = mocker.patch(
|
||||
"backend.api.features.v1.create_subscription_checkout",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/subscription",
|
||||
json={
|
||||
"tier": "BUSINESS",
|
||||
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
|
||||
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == ""
|
||||
modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS)
|
||||
# DB tier updated directly — no Stripe Checkout Session created
|
||||
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS)
|
||||
checkout_mock.assert_not_awaited()
|
||||
|
||||
|
||||
def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""POST /credits/subscription returns 502 when Stripe modification fails."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.modify_stripe_subscription_for_tier",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=stripe.StripeError("connection error"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/subscription",
|
||||
json={
|
||||
"tier": "BUSINESS",
|
||||
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
|
||||
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 502
|
||||
|
||||
|
||||
def test_update_subscription_tier_free_no_stripe_subscription(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Downgrading to FREE when no Stripe subscription exists updates DB tier directly.
|
||||
|
||||
Admin-granted paid tiers have no associated Stripe subscription. When such a
|
||||
user requests a self-service downgrade, cancel_stripe_subscription returns False
|
||||
(nothing to cancel), so the endpoint must immediately call set_subscription_tier
|
||||
rather than waiting for a webhook that will never arrive.
|
||||
"""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
)
|
||||
# Simulate no active Stripe subscriptions — returns False
|
||||
cancel_mock = mocker.patch(
|
||||
"backend.api.features.v1.cancel_stripe_subscription",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
)
|
||||
set_tier_mock = mocker.patch(
|
||||
"backend.api.features.v1.set_subscription_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post("/credits/subscription", json={"tier": "FREE"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == ""
|
||||
cancel_mock.assert_awaited_once_with(TEST_USER_ID)
|
||||
# DB tier must be updated immediately — no webhook will fire for a missing sub
|
||||
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.FREE)
|
||||
@@ -5,7 +5,8 @@ import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Sequence, get_args
|
||||
from typing import Annotated, Any, Literal, Sequence, cast, get_args
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pydantic
|
||||
import stripe
|
||||
@@ -24,6 +25,7 @@ from fastapi import (
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from prisma.enums import SubscriptionTier
|
||||
from pydantic import BaseModel
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
@@ -50,9 +52,17 @@ from backend.data.credit import (
|
||||
RefundRequest,
|
||||
TransactionHistory,
|
||||
UserCredit,
|
||||
cancel_stripe_subscription,
|
||||
create_subscription_checkout,
|
||||
get_auto_top_up,
|
||||
get_proration_credit_cents,
|
||||
get_subscription_price_id,
|
||||
get_user_credit_model,
|
||||
handle_subscription_payment_failure,
|
||||
modify_stripe_subscription_for_tier,
|
||||
set_auto_top_up,
|
||||
set_subscription_tier,
|
||||
sync_subscription_from_stripe,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
@@ -63,12 +73,17 @@ from backend.data.onboarding import (
|
||||
UserOnboardingUpdate,
|
||||
complete_onboarding_step,
|
||||
complete_re_run_agent,
|
||||
format_onboarding_for_extraction,
|
||||
get_recommended_agents,
|
||||
get_user_onboarding,
|
||||
onboarding_enabled,
|
||||
reset_user_onboarding,
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.tally import extract_business_understanding
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
@@ -282,35 +297,33 @@ async def get_onboarding_agents(
|
||||
return await get_recommended_agents(user_id)
|
||||
|
||||
|
||||
class OnboardingStatusResponse(pydantic.BaseModel):
|
||||
"""Response for onboarding status check."""
|
||||
class OnboardingProfileRequest(pydantic.BaseModel):
|
||||
"""Request body for onboarding profile submission."""
|
||||
|
||||
is_onboarding_enabled: bool
|
||||
is_chat_enabled: bool
|
||||
user_name: str = pydantic.Field(min_length=1, max_length=100)
|
||||
user_role: str = pydantic.Field(min_length=1, max_length=100)
|
||||
pain_points: list[str] = pydantic.Field(default_factory=list, max_length=20)
|
||||
|
||||
|
||||
class OnboardingStatusResponse(pydantic.BaseModel):
|
||||
"""Response for onboarding completion check."""
|
||||
|
||||
is_completed: bool
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/onboarding/enabled",
|
||||
summary="Is onboarding enabled",
|
||||
"/onboarding/completed",
|
||||
summary="Check if onboarding is completed",
|
||||
tags=["onboarding", "public"],
|
||||
response_model=OnboardingStatusResponse,
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def is_onboarding_enabled(
|
||||
async def is_onboarding_completed(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> OnboardingStatusResponse:
|
||||
# Check if chat is enabled for user
|
||||
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
|
||||
|
||||
# If chat is enabled, skip legacy onboarding
|
||||
if is_chat_enabled:
|
||||
return OnboardingStatusResponse(
|
||||
is_onboarding_enabled=False,
|
||||
is_chat_enabled=True,
|
||||
)
|
||||
|
||||
user_onboarding = await get_user_onboarding(user_id)
|
||||
return OnboardingStatusResponse(
|
||||
is_onboarding_enabled=await onboarding_enabled(),
|
||||
is_chat_enabled=False,
|
||||
is_completed=OnboardingStep.VISIT_COPILOT in user_onboarding.completedSteps,
|
||||
)
|
||||
|
||||
|
||||
@@ -325,6 +338,38 @@ async def reset_onboarding(user_id: Annotated[str, Security(get_user_id)]):
|
||||
return await reset_user_onboarding(user_id)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/onboarding/profile",
|
||||
summary="Submit onboarding profile",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def submit_onboarding_profile(
|
||||
data: OnboardingProfileRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
):
|
||||
formatted = format_onboarding_for_extraction(
|
||||
user_name=data.user_name,
|
||||
user_role=data.user_role,
|
||||
pain_points=data.pain_points,
|
||||
)
|
||||
|
||||
try:
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
except Exception:
|
||||
understanding_input = BusinessUnderstandingInput.model_construct()
|
||||
|
||||
# Ensure the direct fields are set even if LLM missed them
|
||||
understanding_input.user_name = data.user_name
|
||||
understanding_input.user_role = data.user_role
|
||||
if not understanding_input.pain_points:
|
||||
understanding_input.pain_points = data.pain_points
|
||||
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Blocks ###########################
|
||||
########################################################
|
||||
@@ -626,9 +671,12 @@ async def configure_user_auto_top_up(
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
raise
|
||||
|
||||
await set_auto_top_up(
|
||||
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
|
||||
)
|
||||
try:
|
||||
await set_auto_top_up(
|
||||
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
return "Auto top-up settings updated"
|
||||
|
||||
|
||||
@@ -644,41 +692,371 @@ async def get_user_auto_top_up(
|
||||
return await get_auto_top_up(user_id)
|
||||
|
||||
|
||||
class SubscriptionTierRequest(BaseModel):
|
||||
tier: Literal["FREE", "PRO", "BUSINESS"]
|
||||
success_url: str = ""
|
||||
cancel_url: str = ""
|
||||
|
||||
|
||||
class SubscriptionCheckoutResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
|
||||
monthly_cost: int # amount in cents (Stripe convention)
|
||||
tier_costs: dict[str, int] # tier name -> amount in cents
|
||||
proration_credit_cents: int # unused portion of current sub to convert on upgrade
|
||||
|
||||
|
||||
def _validate_checkout_redirect_url(url: str) -> bool:
|
||||
"""Return True if `url` matches the configured frontend origin.
|
||||
|
||||
Prevents open-redirect: attackers must not be able to supply arbitrary
|
||||
success_url/cancel_url that Stripe will redirect users to after checkout.
|
||||
|
||||
Pre-parse rejection rules (applied before urlparse):
|
||||
- Backslashes (``\\``) are normalised differently across parsers/browsers.
|
||||
- Control characters (U+0000–U+001F) are not valid in URLs and may confuse
|
||||
some URL-parsing implementations.
|
||||
"""
|
||||
# Reject characters that can confuse URL parsers before any parsing.
|
||||
if "\\" in url:
|
||||
return False
|
||||
if any(ord(c) < 0x20 for c in url):
|
||||
return False
|
||||
|
||||
allowed = settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
if not allowed:
|
||||
# No configured origin — refuse to validate rather than allow arbitrary URLs.
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
allowed_parsed = urlparse(allowed)
|
||||
except ValueError:
|
||||
return False
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
# Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
|
||||
# can trick browsers into connecting to a different host than displayed.
|
||||
# ``@`` in query/fragment is harmless and must be allowed.
|
||||
if "@" in parsed.netloc:
|
||||
return False
|
||||
return (
|
||||
parsed.scheme == allowed_parsed.scheme
|
||||
and parsed.netloc == allowed_parsed.netloc
|
||||
)
|
||||
|
||||
|
||||
@cached(ttl_seconds=300, maxsize=32, cache_none=False)
|
||||
async def _get_stripe_price_amount(price_id: str) -> int | None:
|
||||
"""Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
|
||||
|
||||
Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out
|
||||
of caching the ``None`` sentinel so the next request retries Stripe instead
|
||||
of being served a stale "no price" for the rest of the TTL window. Callers
|
||||
should treat ``None`` as an unknown price and fall back to 0.
|
||||
|
||||
Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
|
||||
every GET /credits/subscription page load and reduces quota consumption.
|
||||
"""
|
||||
try:
|
||||
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
|
||||
return price.unit_amount or 0
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"Failed to retrieve Stripe price %s — returning None (not cached)",
|
||||
price_id,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/credits/subscription",
|
||||
summary="Get subscription tier, current cost, and all tier costs",
|
||||
operation_id="getSubscriptionStatus",
|
||||
tags=["credits"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_subscription_status(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> SubscriptionStatusResponse:
|
||||
user = await get_user_by_id(user_id)
|
||||
tier = user.subscription_tier or SubscriptionTier.FREE
|
||||
|
||||
paid_tiers = [SubscriptionTier.PRO, SubscriptionTier.BUSINESS]
|
||||
price_ids = await asyncio.gather(
|
||||
*[get_subscription_price_id(t) for t in paid_tiers]
|
||||
)
|
||||
|
||||
tier_costs: dict[str, int] = {
|
||||
SubscriptionTier.FREE.value: 0,
|
||||
SubscriptionTier.ENTERPRISE.value: 0,
|
||||
}
|
||||
|
||||
async def _cost(pid: str | None) -> int:
|
||||
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
|
||||
|
||||
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
|
||||
for t, cost in zip(paid_tiers, costs):
|
||||
tier_costs[t.value] = cost
|
||||
|
||||
current_monthly_cost = tier_costs.get(tier.value, 0)
|
||||
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
|
||||
|
||||
return SubscriptionStatusResponse(
|
||||
tier=tier.value,
|
||||
monthly_cost=current_monthly_cost,
|
||||
tier_costs=tier_costs,
|
||||
proration_credit_cents=proration_credit,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/subscription",
|
||||
summary="Start a Stripe Checkout session to upgrade subscription tier",
|
||||
operation_id="updateSubscriptionTier",
|
||||
tags=["credits"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_subscription_tier(
|
||||
request: SubscriptionTierRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> SubscriptionCheckoutResponse:
|
||||
# Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type.
|
||||
tier = SubscriptionTier(request.tier)
|
||||
|
||||
# ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
|
||||
user = await get_user_by_id(user_id)
|
||||
if (user.subscription_tier or SubscriptionTier.FREE) == SubscriptionTier.ENTERPRISE:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="ENTERPRISE subscription changes must be managed by an administrator",
|
||||
)
|
||||
|
||||
payment_enabled = await is_feature_enabled(
|
||||
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
|
||||
)
|
||||
|
||||
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
|
||||
# keeps their tier for the time they already paid for. The DB tier is NOT
|
||||
# updated here when a subscription exists — the customer.subscription.deleted
|
||||
# webhook fires at period end and downgrades to FREE then.
|
||||
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
|
||||
# tier), cancel_stripe_subscription returns False and we update the DB tier
|
||||
# immediately since no webhook will ever fire.
|
||||
# When payment is disabled entirely, update the DB tier directly.
|
||||
if tier == SubscriptionTier.FREE:
|
||||
if payment_enabled:
|
||||
try:
|
||||
had_subscription = await cancel_stripe_subscription(user_id)
|
||||
except stripe.StripeError as e:
|
||||
# Log full Stripe error server-side but return a generic message
|
||||
# to the client — raw Stripe errors can leak customer/sub IDs and
|
||||
# infrastructure config details.
|
||||
logger.exception(
|
||||
"Stripe error cancelling subscription for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to cancel your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
if not had_subscription:
|
||||
# No active Stripe subscription found — the user was on an
|
||||
# admin-granted tier. Update DB immediately since the
|
||||
# subscription.deleted webhook will never fire.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
|
||||
# Paid tier changes require payment to be enabled — block self-service upgrades
|
||||
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
|
||||
if not payment_enabled:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Subscription not available for tier {tier}",
|
||||
)
|
||||
|
||||
# No-op short-circuit: if the user is already on the requested paid tier,
|
||||
# do NOT create a new Checkout Session. Without this guard, a duplicate
|
||||
# request (double-click, retried POST, stale page) creates a second
|
||||
# subscription for the same price; the user would be charged for both
|
||||
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
|
||||
# which only fires after the second charge has cleared.
|
||||
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
|
||||
# Paid→paid tier change: if the user already has a Stripe subscription,
|
||||
# modify it in-place with proration instead of creating a new Checkout
|
||||
# Session. This preserves remaining paid time and avoids double-charging.
|
||||
# The customer.subscription.updated webhook fires and updates the DB tier.
|
||||
current_tier = user.subscription_tier or SubscriptionTier.FREE
|
||||
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
|
||||
try:
|
||||
modified = await modify_stripe_subscription_for_tier(user_id, tier)
|
||||
if modified:
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
# modify_stripe_subscription_for_tier returns False when no active
|
||||
# Stripe subscription exists — i.e. the user has an admin-granted
|
||||
# paid tier with no Stripe record. In that case, update the DB
|
||||
# tier directly (same as the FREE-downgrade path for admin-granted
|
||||
# users) rather than sending them through a new Checkout Session.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error modifying subscription for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to update your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
# Paid upgrade from FREE → create Stripe Checkout Session.
|
||||
if not request.success_url or not request.cancel_url:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url are required for paid tier upgrades",
|
||||
)
|
||||
# Open-redirect protection: both URLs must point to the configured frontend
|
||||
# origin, otherwise an attacker could use our Stripe integration as a
|
||||
# redirector to arbitrary phishing sites.
|
||||
#
|
||||
# Fail early with a clear 503 if the server is misconfigured (neither
|
||||
# frontend_base_url nor platform_base_url set), so operators get an
|
||||
# actionable error instead of the misleading "must match the platform
|
||||
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
|
||||
# produce when `allowed` is empty.
|
||||
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
|
||||
logger.error(
|
||||
"update_subscription_tier: neither frontend_base_url nor "
|
||||
"platform_base_url is configured; cannot validate checkout redirect URLs"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Payment redirect URLs cannot be validated: "
|
||||
"frontend_base_url or platform_base_url must be set on the server."
|
||||
),
|
||||
)
|
||||
if not _validate_checkout_redirect_url(
|
||||
request.success_url
|
||||
) or not _validate_checkout_redirect_url(request.cancel_url):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url must match the platform frontend origin",
|
||||
)
|
||||
try:
|
||||
url = await create_subscription_checkout(
|
||||
user_id=user_id,
|
||||
tier=tier,
|
||||
success_url=request.success_url,
|
||||
cancel_url=request.cancel_url,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error creating checkout session for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to start checkout right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
return SubscriptionCheckoutResponse(url=url)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
|
||||
)
|
||||
async def stripe_webhook(request: Request):
|
||||
webhook_secret = settings.secrets.stripe_webhook_secret
|
||||
if not webhook_secret:
|
||||
# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
|
||||
# signature over the same empty key). Reject all webhook calls when unconfigured.
|
||||
logger.error(
|
||||
"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
|
||||
"rejecting request to prevent signature bypass"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Webhook not configured")
|
||||
|
||||
# Get the raw request body
|
||||
payload = await request.body()
|
||||
# Get the signature header
|
||||
sig_header = request.headers.get("stripe-signature")
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.secrets.stripe_webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
|
||||
except ValueError:
|
||||
# Invalid payload
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
|
||||
)
|
||||
except stripe.SignatureVerificationError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid payload")
|
||||
except stripe.SignatureVerificationError:
|
||||
# Invalid signature
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
# Defensive payload extraction. A malformed payload (missing/non-dict
|
||||
# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
|
||||
# AFTER signature verification — which Stripe interprets as a delivery
|
||||
# failure and retries forever, while spamming Sentry with no useful info.
|
||||
# Acknowledge with 200 and a warning so Stripe stops retrying.
|
||||
event_type = event.get("type", "")
|
||||
event_data = event.get("data") or {}
|
||||
data_object = event_data.get("object") if isinstance(event_data, dict) else None
|
||||
if not isinstance(data_object, dict):
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing or non-dict data.object; ignoring",
|
||||
event_type,
|
||||
)
|
||||
return Response(status_code=200)
|
||||
|
||||
if (
|
||||
event["type"] == "checkout.session.completed"
|
||||
or event["type"] == "checkout.session.async_payment_succeeded"
|
||||
if event_type in (
|
||||
"checkout.session.completed",
|
||||
"checkout.session.async_payment_succeeded",
|
||||
):
|
||||
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
|
||||
session_id = data_object.get("id")
|
||||
if not session_id:
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing data.object.id; ignoring", event_type
|
||||
)
|
||||
return Response(status_code=200)
|
||||
await UserCredit().fulfill_checkout(session_id=session_id)
|
||||
|
||||
if event["type"] == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(event["data"]["object"])
|
||||
if event_type in (
|
||||
"customer.subscription.created",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
):
|
||||
await sync_subscription_from_stripe(data_object)
|
||||
|
||||
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(event["data"]["object"])
|
||||
if event_type == "invoice.payment_failed":
|
||||
await handle_subscription_payment_failure(data_object)
|
||||
|
||||
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
|
||||
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
|
||||
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
|
||||
# to satisfy the type checker without changing runtime behaviour.
|
||||
if event_type == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
|
||||
|
||||
if event_type == "refund.created" or event_type == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(
|
||||
cast("stripe.Refund | stripe.Dispute", data_object)
|
||||
)
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
@@ -131,9 +131,26 @@ class StorageUsageResponse(BaseModel):
|
||||
file_count: int
|
||||
|
||||
|
||||
class WorkspaceFileItem(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
created_at: str
|
||||
|
||||
|
||||
class ListFilesResponse(BaseModel):
|
||||
files: list[WorkspaceFileItem]
|
||||
offset: int = 0
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files/{file_id}/download",
|
||||
summary="Download file by ID",
|
||||
operation_id="getWorkspaceDownloadFileById",
|
||||
)
|
||||
async def download_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -158,6 +175,7 @@ async def download_file(
|
||||
@router.delete(
|
||||
"/files/{file_id}",
|
||||
summary="Delete a workspace file",
|
||||
operation_id="deleteWorkspaceFile",
|
||||
)
|
||||
async def delete_workspace_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -183,6 +201,7 @@ async def delete_workspace_file(
|
||||
@router.post(
|
||||
"/files/upload",
|
||||
summary="Upload file to workspace",
|
||||
operation_id="uploadWorkspaceFile",
|
||||
)
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -196,6 +215,9 @@ async def upload_file(
|
||||
Files are stored in session-scoped paths when session_id is provided,
|
||||
so the agent's session-scoped tools can discover them automatically.
|
||||
"""
|
||||
# Empty-string session_id drops session scoping; normalize to None.
|
||||
session_id = session_id or None
|
||||
|
||||
config = Config()
|
||||
|
||||
# Sanitize filename — strip any directory components
|
||||
@@ -250,16 +272,27 @@ async def upload_file(
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
try:
|
||||
workspace_file = await manager.write_file(
|
||||
content, filename, overwrite=overwrite
|
||||
content, filename, overwrite=overwrite, metadata={"origin": "user-upload"}
|
||||
)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
# write_file raises ValueError for both path-conflict and size-limit
|
||||
# cases; map each to its correct HTTP status.
|
||||
message = str(e)
|
||||
if message.startswith("File too large"):
|
||||
raise fastapi.HTTPException(status_code=413, detail=message) from e
|
||||
raise fastapi.HTTPException(status_code=409, detail=message) from e
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
new_total = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and new_total > storage_limit_bytes:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
try:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to soft-delete over-quota file {workspace_file.id} "
|
||||
f"in workspace {workspace.id}: {e}"
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
@@ -281,6 +314,7 @@ async def upload_file(
|
||||
@router.get(
|
||||
"/storage/usage",
|
||||
summary="Get workspace storage usage",
|
||||
operation_id="getWorkspaceStorageUsage",
|
||||
)
|
||||
async def get_storage_usage(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -301,3 +335,57 @@ async def get_storage_usage(
|
||||
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files",
|
||||
summary="List workspace files",
|
||||
operation_id="listWorkspaceFiles",
|
||||
)
|
||||
async def list_workspace_files(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
session_id: str | None = Query(default=None),
|
||||
limit: int = Query(default=200, ge=1, le=1000),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> ListFilesResponse:
|
||||
"""
|
||||
List files in the user's workspace.
|
||||
|
||||
When session_id is provided, only files for that session are returned.
|
||||
Otherwise, all files across sessions are listed. Results are paginated
|
||||
via `limit`/`offset`; `has_more` indicates whether additional pages exist.
|
||||
"""
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Treat empty-string session_id the same as omitted — an empty value
|
||||
# would otherwise silently list files across every session instead of
|
||||
# scoping to one.
|
||||
session_id = session_id or None
|
||||
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
include_all = session_id is None
|
||||
# Fetch one extra to compute has_more without a separate count query.
|
||||
files = await manager.list_files(
|
||||
limit=limit + 1,
|
||||
offset=offset,
|
||||
include_all_sessions=include_all,
|
||||
)
|
||||
has_more = len(files) > limit
|
||||
page = files[:limit]
|
||||
|
||||
return ListFilesResponse(
|
||||
files=[
|
||||
WorkspaceFileItem(
|
||||
id=f.id,
|
||||
name=f.name,
|
||||
path=f.path,
|
||||
mime_type=f.mime_type,
|
||||
size_bytes=f.size_bytes,
|
||||
metadata=f.metadata or {},
|
||||
created_at=f.created_at.isoformat(),
|
||||
)
|
||||
for f in page
|
||||
],
|
||||
offset=offset,
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
@@ -1,48 +1,28 @@
|
||||
"""Tests for workspace file upload and download routes."""
|
||||
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.workspace import routes as workspace_routes
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
from backend.api.features.workspace.routes import router
|
||||
from backend.data.workspace import Workspace, WorkspaceFile
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(workspace_routes.router)
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def _value_error_handler(
|
||||
request: fastapi.Request, exc: ValueError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
|
||||
"""Mirror the production ValueError → 400 mapping from the REST app."""
|
||||
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
|
||||
|
||||
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-1",
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
name="hello.txt",
|
||||
path="/session/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
@@ -53,25 +33,201 @@ def setup_app_auth(mock_jwt_user):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _make_workspace(user_id: str = "test-user-id") -> Workspace:
|
||||
return Workspace(
|
||||
id="ws-001",
|
||||
user_id=user_id,
|
||||
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _make_file(**overrides) -> WorkspaceFile:
|
||||
defaults = {
|
||||
"id": "file-001",
|
||||
"workspace_id": "ws-001",
|
||||
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"name": "test.txt",
|
||||
"path": "/test.txt",
|
||||
"storage_path": "local://test.txt",
|
||||
"mime_type": "text/plain",
|
||||
"size_bytes": 100,
|
||||
"checksum": None,
|
||||
"is_deleted": False,
|
||||
"deleted_at": None,
|
||||
"metadata": {},
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return WorkspaceFile(**defaults)
|
||||
|
||||
|
||||
def _make_file_mock(**overrides) -> MagicMock:
|
||||
"""Create a mock WorkspaceFile to simulate DB records with null fields."""
|
||||
defaults = {
|
||||
"id": "file-001",
|
||||
"name": "test.txt",
|
||||
"path": "/test.txt",
|
||||
"mime_type": "text/plain",
|
||||
"size_bytes": 100,
|
||||
"metadata": {},
|
||||
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
mock = MagicMock(spec=WorkspaceFile)
|
||||
for k, v in defaults.items():
|
||||
setattr(mock, k, v)
|
||||
return mock
|
||||
|
||||
|
||||
# -- list_workspace_files tests --
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_returns_all_when_no_session(mock_manager_cls, mock_get_workspace):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
files = [
|
||||
_make_file(id="f1", name="a.txt", metadata={"origin": "user-upload"}),
|
||||
_make_file(id="f2", name="b.csv", metadata={"origin": "agent-created"}),
|
||||
]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert len(data["files"]) == 2
|
||||
assert data["has_more"] is False
|
||||
assert data["offset"] == 0
|
||||
assert data["files"][0]["id"] == "f1"
|
||||
assert data["files"][0]["metadata"] == {"origin": "user-upload"}
|
||||
assert data["files"][1]["id"] == "f2"
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=201, offset=0, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_scopes_to_session_when_provided(
|
||||
mock_manager_cls, mock_get_workspace, test_user_id
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace(user_id=test_user_id)
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = []
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?session_id=sess-123")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert data["files"] == []
|
||||
assert data["has_more"] is False
|
||||
mock_manager_cls.assert_called_once_with(test_user_id, "ws-001", "sess-123")
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=201, offset=0, include_all_sessions=False
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_null_metadata_coerced_to_empty_dict(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""Route uses `f.metadata or {}` for pre-existing files with null metadata."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = [_make_file_mock(metadata=None)]
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["files"][0]["metadata"] == {}
|
||||
|
||||
|
||||
# -- upload_file metadata tests --
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
|
||||
@patch("backend.api.features.workspace.routes.scan_content_safe")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_upload_passes_user_upload_origin_metadata(
|
||||
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_total_size.return_value = 100
|
||||
written = _make_file(id="new-file", name="doc.pdf")
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.write_file.return_value = written
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.post(
|
||||
"/files/upload",
|
||||
files={"file": ("doc.pdf", b"fake-pdf-content", "application/pdf")},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
mock_instance.write_file.assert_called_once()
|
||||
call_kwargs = mock_instance.write_file.call_args
|
||||
assert call_kwargs.kwargs.get("metadata") == {"origin": "user-upload"}
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
|
||||
@patch("backend.api.features.workspace.routes.scan_content_safe")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_upload_returns_409_on_file_conflict(
|
||||
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_total_size.return_value = 100
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.write_file.side_effect = ValueError("File already exists at path")
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.post(
|
||||
"/files/upload",
|
||||
files={"file": ("dup.txt", b"content", "text/plain")},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
# -- Restored upload/download/delete security + invariant tests --
|
||||
|
||||
|
||||
def _upload(
|
||||
filename: str = "hello.txt",
|
||||
content: bytes = b"Hello, world!",
|
||||
content_type: str = "text/plain",
|
||||
):
|
||||
"""Helper to POST a file upload."""
|
||||
return client.post(
|
||||
"/files/upload?session_id=sess-1",
|
||||
files={"file": (filename, io.BytesIO(content), content_type)},
|
||||
)
|
||||
|
||||
|
||||
# ---- Happy path ----
|
||||
_MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-001",
|
||||
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
name="hello.txt",
|
||||
path="/sessions/sess-1/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_happy_path(mocker):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -82,7 +238,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -96,10 +252,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
assert data["size_bytes"] == 13
|
||||
|
||||
|
||||
# ---- Per-file size limit ----
|
||||
|
||||
|
||||
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_exceeds_max_file_size(mocker):
|
||||
"""Files larger than max_file_size_mb should be rejected with 413."""
|
||||
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
|
||||
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
|
||||
@@ -109,15 +262,11 @@ def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 413
|
||||
|
||||
|
||||
# ---- Storage quota exceeded ----
|
||||
|
||||
|
||||
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_storage_quota_exceeded(mocker):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
# Current usage already at limit
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=500 * 1024 * 1024,
|
||||
@@ -128,27 +277,22 @@ def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
assert "Storage limit exceeded" in response.text
|
||||
|
||||
|
||||
# ---- Post-write quota race (B2) ----
|
||||
|
||||
|
||||
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
"""If a concurrent upload tips the total over the limit after write,
|
||||
the file should be soft-deleted and 413 returned."""
|
||||
def test_upload_post_write_quota_race(mocker):
|
||||
"""Concurrent upload tipping over limit after write should soft-delete + 413."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
# Pre-write check passes (under limit), but post-write check fails
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
|
||||
side_effect=[0, 600 * 1024 * 1024],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -160,17 +304,14 @@ def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-001")
|
||||
|
||||
|
||||
# ---- Any extension accepted (no allowlist) ----
|
||||
|
||||
|
||||
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_any_extension(mocker):
|
||||
"""Any file extension should be accepted — ClamAV is the security layer."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -181,7 +322,7 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -191,16 +332,13 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- Virus scan rejection ----
|
||||
|
||||
|
||||
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_blocked_by_virus_scan(mocker):
|
||||
"""Files flagged by ClamAV should be rejected and never written to storage."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -211,7 +349,7 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
side_effect=VirusDetectedError("Eicar-Test-Signature"),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -219,18 +357,14 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
|
||||
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
|
||||
assert response.status_code == 400
|
||||
assert "Virus detected" in response.text
|
||||
mock_manager.write_file.assert_not_called()
|
||||
|
||||
|
||||
# ---- No file extension ----
|
||||
|
||||
|
||||
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_file_without_extension(mocker):
|
||||
"""Files without an extension should be accepted and stored as-is."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -241,7 +375,7 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -257,14 +391,11 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
assert mock_manager.write_file.call_args[0][1] == "Makefile"
|
||||
|
||||
|
||||
# ---- Filename sanitization (SF5) ----
|
||||
|
||||
|
||||
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_strips_path_components(mocker):
|
||||
"""Path-traversal filenames should be reduced to their basename."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -275,28 +406,23 @@ def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
# Filename with traversal
|
||||
_upload(filename="../../etc/passwd.txt")
|
||||
|
||||
# write_file should have been called with just the basename
|
||||
mock_manager.write_file.assert_called_once()
|
||||
call_args = mock_manager.write_file.call_args
|
||||
assert call_args[0][1] == "passwd.txt"
|
||||
|
||||
|
||||
# ---- Download ----
|
||||
|
||||
|
||||
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
def test_download_file_not_found(mocker):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_file",
|
||||
@@ -307,14 +433,11 @@ def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- Delete ----
|
||||
|
||||
|
||||
def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
def test_delete_file_success(mocker):
|
||||
"""Deleting an existing file should return {"deleted": true}."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
|
||||
@@ -329,11 +452,11 @@ def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
|
||||
|
||||
|
||||
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
def test_delete_file_not_found(mocker):
|
||||
"""Deleting a non-existent file should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
|
||||
@@ -347,7 +470,7 @@ def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
assert "File not found" in response.text
|
||||
|
||||
|
||||
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
def test_delete_file_no_workspace(mocker):
|
||||
"""Deleting when user has no workspace should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
@@ -357,3 +480,123 @@ def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 404
|
||||
assert "Workspace not found" in response.text
|
||||
|
||||
|
||||
def test_upload_write_file_too_large_returns_413(mocker):
|
||||
"""write_file raises ValueError("File too large: …") → must map to 413."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
assert "File too large" in response.text
|
||||
|
||||
|
||||
def test_upload_write_file_conflict_returns_409(mocker):
|
||||
"""Non-'File too large' ValueErrors from write_file stay as 409."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=ValueError("File already exists at path: /sessions/x/a.txt")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.text
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_has_more_true_when_limit_exceeded(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""The limit+1 fetch trick must flip has_more=True and trim the page."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
# Backend was asked for limit+1=3, and returned exactly 3 items.
|
||||
files = [
|
||||
_make_file(id="f1", name="a.txt"),
|
||||
_make_file(id="f2", name="b.txt"),
|
||||
_make_file(id="f3", name="c.txt"),
|
||||
]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["has_more"] is True
|
||||
assert len(data["files"]) == 2
|
||||
assert data["files"][0]["id"] == "f1"
|
||||
assert data["files"][1]["id"] == "f2"
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=3, offset=0, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_has_more_false_when_exactly_page_size(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""Exactly `limit` rows means we're on the last page — has_more=False."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
files = [_make_file(id="f1", name="a.txt"), _make_file(id="f2", name="b.txt")]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["has_more"] is False
|
||||
assert len(data["files"]) == 2
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = []
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?offset=50&limit=10")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["offset"] == 50
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=11, offset=50, include_all_sessions=True
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.platform_cost_routes
|
||||
import backend.api.features.admin.rate_limit_admin_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
@@ -118,6 +119,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Register managed credential providers (e.g. AgentMail)
|
||||
from backend.integrations.managed_providers import register_all
|
||||
|
||||
register_all()
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
@@ -324,6 +330,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/copilot",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.platform_cost_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/admin",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
|
||||
@@ -25,6 +25,7 @@ from backend.data.model import (
|
||||
Credentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
@@ -43,7 +44,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import ContributorDetails, NodeExecutionStats
|
||||
from backend.data.model import ContributorDetails
|
||||
|
||||
from ..data.graph import Link
|
||||
|
||||
@@ -420,6 +421,19 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
|
||||
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
_optimized_description: ClassVar[str | None] = None
|
||||
|
||||
def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
|
||||
"""Return extra runtime cost to charge after this block run completes.
|
||||
|
||||
Called by the executor after a block finishes with COMPLETED status.
|
||||
The return value is the number of additional base-cost credits to
|
||||
charge beyond the single credit already collected by charge_usage
|
||||
at the start of execution. Defaults to 0 (no extra charges).
|
||||
|
||||
Override in blocks (e.g. OrchestratorBlock) that make multiple LLM
|
||||
calls within one run and should be billed per call.
|
||||
"""
|
||||
return 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str = "",
|
||||
@@ -455,8 +469,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
disabled: If the block is disabled, it will not be available for execution.
|
||||
static_output: Whether the output links of the block are static by default.
|
||||
"""
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
self.id = id
|
||||
self.input_schema = input_schema
|
||||
self.output_schema = output_schema
|
||||
@@ -474,7 +486,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.is_sensitive_action = is_sensitive_action
|
||||
# Read from ClassVar set by initialize_blocks()
|
||||
self.optimized_description: str | None = type(self)._optimized_description
|
||||
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -554,7 +566,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
return data
|
||||
raise ValueError(f"{self.name} did not produce any output for {output}")
|
||||
|
||||
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
|
||||
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
|
||||
self.execution_stats += stats
|
||||
return self.execution_stats
|
||||
|
||||
@@ -698,13 +710,30 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
if should_pause:
|
||||
return
|
||||
|
||||
# Validate the input data (original or reviewer-modified) once
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
# Validate the input data (original or reviewer-modified) once.
|
||||
# In dry-run mode, credential fields may contain sentinel None values
|
||||
# that would fail JSON schema required checks. We still validate the
|
||||
# non-credential fields so blocks that execute for real during dry-run
|
||||
# (e.g. AgentExecutorBlock) get proper input validation.
|
||||
is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
|
||||
if is_dry_run:
|
||||
cred_field_names = set(self.input_schema.get_credentials_fields().keys())
|
||||
non_cred_data = {
|
||||
k: v for k, v in input_data.items() if k not in cred_field_names
|
||||
}
|
||||
if error := self.input_schema.validate_data(non_cred_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
else:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Use the validated input data
|
||||
async for output_name, output_data in self.run(
|
||||
|
||||
@@ -49,11 +49,17 @@ class AgentExecutorBlock(Block):
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
required_fields = cls.get_input_schema(data).get("required", [])
|
||||
return set(required_fields) - set(data)
|
||||
# Check against the nested `inputs` dict, not the top-level node
|
||||
# data — required fields like "topic" live inside data["inputs"],
|
||||
# not at data["topic"].
|
||||
provided = data.get("inputs", {})
|
||||
return set(required_fields) - set(provided)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
return validate_with_jsonschema(
|
||||
cls.get_input_schema(data), data.get("inputs", {})
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Use BlockSchema to avoid automatic error field that could clash with graph outputs
|
||||
@@ -88,6 +94,7 @@ class AgentExecutorBlock(Block):
|
||||
execution_context=execution_context.model_copy(
|
||||
update={"parent_execution_id": graph_exec_id},
|
||||
),
|
||||
dry_run=execution_context.dry_run,
|
||||
)
|
||||
|
||||
logger = execution_utils.LogMetadata(
|
||||
@@ -149,14 +156,19 @@ class AgentExecutorBlock(Block):
|
||||
ExecutionStatus.TERMINATED,
|
||||
ExecutionStatus.FAILED,
|
||||
]:
|
||||
logger.debug(
|
||||
f"Execution {log_id} received event {event.event_type} with status {event.status}"
|
||||
logger.info(
|
||||
f"Execution {log_id} skipping event {event.event_type} status={event.status} "
|
||||
f"node={getattr(event, 'node_exec_id', '?')}"
|
||||
)
|
||||
continue
|
||||
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
|
||||
# we can stop listening for further events.
|
||||
logger.info(
|
||||
f"Execution {log_id} graph completed with status {event.status}, "
|
||||
f"yielded {len(yielded_node_exec_ids)} outputs"
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
extra_cost=event.stats.cost if event.stats else 0,
|
||||
|
||||
@@ -207,6 +207,9 @@ class AIConditionBlock(AIBlockBase):
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
cache_read_token_count=response.cache_read_tokens,
|
||||
cache_creation_token_count=response.cache_creation_tokens,
|
||||
provider_cost=response.provider_cost,
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
@@ -47,7 +47,13 @@ def _make_input(**overrides) -> AIConditionBlock.Input:
|
||||
return AIConditionBlock.Input(**defaults)
|
||||
|
||||
|
||||
def _mock_llm_response(response_text: str) -> LLMResponse:
|
||||
def _mock_llm_response(
|
||||
response_text: str,
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
provider_cost: float | None = None,
|
||||
) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
@@ -56,6 +62,9 @@ def _mock_llm_response(response_text: str) -> LLMResponse:
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
reasoning=None,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
provider_cost=provider_cost,
|
||||
)
|
||||
|
||||
|
||||
@@ -145,3 +154,35 @@ class TestExceptionPropagation:
|
||||
input_data = _make_input()
|
||||
with pytest.raises(RuntimeError, match="LLM provider error"):
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: cache tokens and provider_cost must be propagated to stats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCacheTokenPropagation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_tokens_propagated_to_stats(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""cache_read_tokens and cache_creation_tokens must be forwarded to
|
||||
NodeExecutionStats so that usage dashboards count cached tokens."""
|
||||
block = AIConditionBlock()
|
||||
|
||||
async def spy_llm(**kwargs):
|
||||
return _mock_llm_response(
|
||||
"true",
|
||||
cache_read_tokens=7,
|
||||
cache_creation_tokens=3,
|
||||
provider_cost=0.0012,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(block, "llm_call", spy_llm)
|
||||
|
||||
input_data = _make_input()
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
|
||||
assert block.execution_stats.cache_read_token_count == 7
|
||||
assert block.execution_stats.cache_creation_token_count == 3
|
||||
assert block.execution_stats.provider_cost == 0.0012
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.apollo.models import (
|
||||
PrimaryPhone,
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
@@ -218,6 +218,11 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
) -> BlockOutput:
|
||||
query = SearchOrganizationsRequest(**input_data.model_dump())
|
||||
organizations = await self.search_organizations(query, credentials)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(organizations)), provider_cost_type="items"
|
||||
)
|
||||
)
|
||||
for organization in organizations:
|
||||
yield "organization", organization
|
||||
yield "organizations", organizations
|
||||
|
||||
@@ -21,7 +21,7 @@ from backend.blocks.apollo.models import (
|
||||
SearchPeopleRequest,
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
@@ -366,4 +366,9 @@ class SearchPeopleBlock(Block):
|
||||
*(enrich_or_fallback(person) for person in people)
|
||||
)
|
||||
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(people)), provider_cost_type="items"
|
||||
)
|
||||
)
|
||||
yield "people", people
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import TypedDict # Needed for Python <3.12 compatibility
|
||||
@@ -32,6 +33,10 @@ logger = logging.getLogger(__name__)
|
||||
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
|
||||
|
||||
class SubAgentRecursionError(RuntimeError):
|
||||
"""Raised when the sub-agent nesting depth limit is exceeded."""
|
||||
|
||||
|
||||
class ToolCallEntry(TypedDict):
|
||||
"""A single tool invocation record from an autopilot execution."""
|
||||
|
||||
@@ -146,6 +151,21 @@ class AutoPilotBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
dry_run: bool = SchemaField(
|
||||
description=(
|
||||
"When enabled, run_block and run_agent tool calls in this "
|
||||
"autopilot session are forced to use dry-run simulation mode. "
|
||||
"No real API calls, side effects, or credits are consumed "
|
||||
"by those tools. Useful for testing agent wiring and "
|
||||
"previewing outputs. "
|
||||
"Only applies when creating a new session (session_id is empty). "
|
||||
"When reusing an existing session_id, the session's original "
|
||||
"dry_run setting is preserved."
|
||||
),
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# timeout_seconds removed: the SDK manages its own heartbeat-based
|
||||
# timeouts internally; wrapping with asyncio.timeout corrupts the
|
||||
# SDK's internal stream (see service.py CRITICAL comment).
|
||||
@@ -232,11 +252,11 @@ class AutoPilotBlock(Block):
|
||||
},
|
||||
)
|
||||
|
||||
async def create_session(self, user_id: str) -> str:
|
||||
async def create_session(self, user_id: str, *, dry_run: bool) -> str:
|
||||
"""Create a new chat session and return its ID (mockable for tests)."""
|
||||
from backend.copilot.model import create_chat_session # avoid circular import
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
session = await create_chat_session(user_id, dry_run=dry_run)
|
||||
return session.session_id
|
||||
|
||||
async def execute_copilot(
|
||||
@@ -367,7 +387,10 @@ class AutoPilotBlock(Block):
|
||||
# even if the downstream stream fails (avoids orphaned sessions).
|
||||
sid = input_data.session_id
|
||||
if not sid:
|
||||
sid = await self.create_session(execution_context.user_id)
|
||||
sid = await self.create_session(
|
||||
execution_context.user_id,
|
||||
dry_run=input_data.dry_run or execution_context.dry_run,
|
||||
)
|
||||
|
||||
# NOTE: No asyncio.timeout() here — the SDK manages its own
|
||||
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout
|
||||
@@ -392,8 +415,41 @@ class AutoPilotBlock(Block):
|
||||
yield "session_id", sid
|
||||
yield "error", "AutoPilot execution was cancelled."
|
||||
raise
|
||||
except SubAgentRecursionError as exc:
|
||||
# Deliberate block — re-enqueueing would immediately hit the limit
|
||||
# again, so skip recovery and just surface the error.
|
||||
yield "session_id", sid
|
||||
yield "error", str(exc)
|
||||
except Exception as exc:
|
||||
yield "session_id", sid
|
||||
# Recovery enqueue must happen BEFORE yielding "error": the block
|
||||
# framework (_base.execute) raises BlockExecutionError immediately
|
||||
# when it sees ("error", ...) and stops consuming the generator,
|
||||
# so any code after that yield is dead code in production.
|
||||
effective_prompt = input_data.prompt
|
||||
if input_data.system_context:
|
||||
effective_prompt = (
|
||||
f"[System Context: {input_data.system_context}]\n\n"
|
||||
f"{input_data.prompt}"
|
||||
)
|
||||
try:
|
||||
await _enqueue_for_recovery(
|
||||
sid,
|
||||
execution_context.user_id,
|
||||
effective_prompt,
|
||||
input_data.dry_run or execution_context.dry_run,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
# Task cancelled during recovery — still yield the error
|
||||
# so the session_id + error pair is visible before re-raising.
|
||||
yield "error", str(exc)
|
||||
raise
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"AutoPilot session %s: recovery enqueue raised unexpectedly",
|
||||
sid[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
yield "error", str(exc)
|
||||
|
||||
|
||||
@@ -421,13 +477,13 @@ def _check_recursion(
|
||||
when the caller exits to restore the previous depth.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the current depth already meets or exceeds the limit.
|
||||
SubAgentRecursionError: If the current depth already meets or exceeds the limit.
|
||||
"""
|
||||
current = _autopilot_recursion_depth.get()
|
||||
inherited = _autopilot_recursion_limit.get()
|
||||
limit = max_depth if inherited is None else min(inherited, max_depth)
|
||||
if current >= limit:
|
||||
raise RuntimeError(
|
||||
raise SubAgentRecursionError(
|
||||
f"AutoPilot recursion depth limit reached ({limit}). "
|
||||
"The autopilot has called itself too many times."
|
||||
)
|
||||
@@ -518,3 +574,51 @@ def _merge_inherited_permissions(
|
||||
# Return the token so the caller can restore the previous value in finally.
|
||||
token = _inherited_permissions.set(merged)
|
||||
return merged, token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recovery helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _enqueue_for_recovery(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
dry_run: bool,
|
||||
) -> None:
|
||||
"""Re-enqueue an orphaned sub-agent session so a fresh executor picks it up.
|
||||
|
||||
When ``execute_copilot`` raises an unexpected exception the sub-agent
|
||||
session is left with ``last_role=user`` and no active consumer — identical
|
||||
to the state that caused Toran's reports of silent sub-agents. Publishing
|
||||
the original prompt back to the copilot queue lets the executor service
|
||||
resume the session without manual intervention.
|
||||
|
||||
Skipped for dry-run sessions (no real consumers listen to the queue for
|
||||
simulated sessions). Any failure to publish is logged and swallowed so
|
||||
it never masks the original exception.
|
||||
"""
|
||||
if dry_run:
|
||||
return
|
||||
try:
|
||||
from backend.copilot.executor.utils import ( # avoid circular import
|
||||
enqueue_copilot_turn,
|
||||
)
|
||||
|
||||
await asyncio.wait_for(
|
||||
enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
turn_id=str(uuid.uuid4()),
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
logger.info("AutoPilot session %s enqueued for recovery", session_id[:12])
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"AutoPilot session %s: failed to enqueue for recovery",
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,712 @@
|
||||
"""Unit tests for merge_stats cost tracking in individual blocks.
|
||||
|
||||
Covers the exa code_context, exa contents, and apollo organization blocks
|
||||
to verify provider cost is correctly extracted and reported.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TEST_EXA_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="exa",
|
||||
api_key=SecretStr("mock-exa-api-key"),
|
||||
title="Mock Exa API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_EXA_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_EXA_CREDENTIALS.provider,
|
||||
"id": TEST_EXA_CREDENTIALS.id,
|
||||
"type": TEST_EXA_CREDENTIALS.type,
|
||||
"title": TEST_EXA_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaCodeContextBlock — cost_dollars is a string like "0.005"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaCodeContextBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_float_cost(self):
|
||||
"""float(cost_dollars) parsed from API string and passed to merge_stats."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-1",
|
||||
"query": "how to use hooks",
|
||||
"response": "Here are some examples...",
|
||||
"resultsCount": 3,
|
||||
"costDollars": "0.005",
|
||||
"searchTime": 1.2,
|
||||
"outputTokens": 100,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="how to use hooks",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
results = []
|
||||
async for output in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
results.append(output)
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(0.005)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cost_dollars_does_not_raise(self):
|
||||
"""When cost_dollars cannot be parsed as float, merge_stats is not called."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-2",
|
||||
"query": "query",
|
||||
"response": "response",
|
||||
"resultsCount": 0,
|
||||
"costDollars": "N/A",
|
||||
"searchTime": 0.5,
|
||||
"outputTokens": 0,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
merge_calls: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: merge_calls.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="query",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merge_calls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cost_is_tracked(self):
|
||||
"""A zero cost_dollars string '0.0' should still be recorded."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-3",
|
||||
"query": "query",
|
||||
"response": "...",
|
||||
"resultsCount": 1,
|
||||
"costDollars": "0.0",
|
||||
"searchTime": 0.1,
|
||||
"outputTokens": 10,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="query",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaContentsBlock — response.cost_dollars.total (CostDollars model)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaContentsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_cost_dollars_total(self):
|
||||
"""provider_cost equals response.cost_dollars.total when present."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
|
||||
block = ExaContentsBlock()
|
||||
|
||||
cost_dollars = CostDollars(total=0.012)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_response.context = None
|
||||
mock_response.statuses = None
|
||||
mock_response.cost_dollars = cost_dollars
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.contents.AsyncExa",
|
||||
return_value=MagicMock(
|
||||
get_contents=AsyncMock(return_value=mock_response)
|
||||
),
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaContentsBlock.Input(
|
||||
urls=["https://example.com"],
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(0.012)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_stats_when_cost_dollars_absent(self):
|
||||
"""When response.cost_dollars is None, merge_stats is not called."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
|
||||
block = ExaContentsBlock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_response.context = None
|
||||
mock_response.statuses = None
|
||||
mock_response.cost_dollars = None
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.contents.AsyncExa",
|
||||
return_value=MagicMock(
|
||||
get_contents=AsyncMock(return_value=mock_response)
|
||||
),
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaContentsBlock.Input(
|
||||
urls=["https://example.com"],
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert accumulated == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SearchOrganizationsBlock — provider_cost = float(len(organizations))
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchOrganizationsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_org_count(self):
|
||||
"""provider_cost == number of returned organizations, type == 'items'."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.models import Organization
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
|
||||
block = SearchOrganizationsBlock()
|
||||
|
||||
fake_orgs = [Organization(id=str(i), name=f"Org{i}") for i in range(3)]
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchOrganizationsBlock,
|
||||
"search_organizations",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_orgs,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchOrganizationsBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
results = []
|
||||
async for output in block.run(
|
||||
input_data,
|
||||
credentials=APOLLO_CREDS,
|
||||
):
|
||||
results.append(output)
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(3.0)
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_org_list_tracks_zero(self):
|
||||
"""An empty organization list results in provider_cost=0.0."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
|
||||
block = SearchOrganizationsBlock()
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchOrganizationsBlock,
|
||||
"search_organizations",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchOrganizationsBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=APOLLO_CREDS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JinaEmbeddingBlock — token count from usage.total_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJinaEmbeddingBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_token_count(self):
|
||||
"""provider token count is recorded when API returns usage.total_tokens."""
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
|
||||
block = JinaEmbeddingBlock()
|
||||
|
||||
api_response = {
|
||||
"data": [{"embedding": [0.1, 0.2, 0.3]}],
|
||||
"usage": {"total_tokens": 42},
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.jina.embeddings.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = JinaEmbeddingBlock.Input(
|
||||
texts=["hello world"],
|
||||
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=JINA_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].input_token_count == 42
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_stats_when_usage_absent(self):
|
||||
"""When API response omits usage field, merge_stats is not called."""
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
|
||||
block = JinaEmbeddingBlock()
|
||||
|
||||
api_response = {
|
||||
"data": [{"embedding": [0.1, 0.2, 0.3]}],
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.jina.embeddings.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = JinaEmbeddingBlock.Input(
|
||||
texts=["hello"],
|
||||
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=JINA_CREDS):
|
||||
pass
|
||||
|
||||
assert accumulated == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UnrealTextToSpeechBlock — character count from input text length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUnrealTextToSpeechBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_character_count(self):
|
||||
"""provider_cost equals len(text) with type='characters'."""
|
||||
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
|
||||
from backend.blocks.text_to_speech_block import (
|
||||
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
|
||||
block = UnrealTextToSpeechBlock()
|
||||
test_text = "Hello, world!"
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
UnrealTextToSpeechBlock,
|
||||
"call_unreal_speech_api",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"OutputUri": "https://example.com/audio.mp3"},
|
||||
),
|
||||
patch.object(block, "merge_stats") as mock_merge,
|
||||
):
|
||||
input_data = UnrealTextToSpeechBlock.Input(
|
||||
text=test_text,
|
||||
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=TTS_CREDS):
|
||||
pass
|
||||
|
||||
mock_merge.assert_called_once()
|
||||
stats = mock_merge.call_args[0][0]
|
||||
assert stats.provider_cost == float(len(test_text))
|
||||
assert stats.provider_cost_type == "characters"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_text_gives_zero_characters(self):
|
||||
"""An empty text string results in provider_cost=0.0."""
|
||||
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
|
||||
from backend.blocks.text_to_speech_block import (
|
||||
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
|
||||
block = UnrealTextToSpeechBlock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
UnrealTextToSpeechBlock,
|
||||
"call_unreal_speech_api",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"OutputUri": "https://example.com/audio.mp3"},
|
||||
),
|
||||
patch.object(block, "merge_stats") as mock_merge,
|
||||
):
|
||||
input_data = UnrealTextToSpeechBlock.Input(
|
||||
text="",
|
||||
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=TTS_CREDS):
|
||||
pass
|
||||
|
||||
mock_merge.assert_called_once()
|
||||
stats = mock_merge.call_args[0][0]
|
||||
assert stats.provider_cost == 0.0
|
||||
assert stats.provider_cost_type == "characters"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GoogleMapsSearchBlock — item count from search_places results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGoogleMapsSearchBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_place_count(self):
|
||||
"""provider_cost equals number of returned places, type == 'items'."""
|
||||
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
|
||||
from backend.blocks.google_maps import (
|
||||
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.google_maps import GoogleMapsSearchBlock
|
||||
|
||||
block = GoogleMapsSearchBlock()
|
||||
|
||||
fake_places = [{"name": f"Place{i}", "address": f"Addr{i}"} for i in range(4)]
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
GoogleMapsSearchBlock,
|
||||
"search_places",
|
||||
return_value=fake_places,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = GoogleMapsSearchBlock.Input(
|
||||
query="coffee shops",
|
||||
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=MAPS_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 4.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_results_tracks_zero(self):
|
||||
"""Zero places returned results in provider_cost=0.0."""
|
||||
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
|
||||
from backend.blocks.google_maps import (
|
||||
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.google_maps import GoogleMapsSearchBlock
|
||||
|
||||
block = GoogleMapsSearchBlock()
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
GoogleMapsSearchBlock,
|
||||
"search_places",
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = GoogleMapsSearchBlock.Input(
|
||||
query="nothing here",
|
||||
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=MAPS_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SmartLeadAddLeadsBlock — item count from lead_list length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSmartLeadAddLeadsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_lead_count(self):
|
||||
"""provider_cost equals number of leads uploaded, type == 'items'."""
|
||||
from backend.blocks.smartlead._auth import TEST_CREDENTIALS as SL_CREDS
|
||||
from backend.blocks.smartlead._auth import (
|
||||
TEST_CREDENTIALS_INPUT as SL_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.smartlead.campaign import AddLeadToCampaignBlock
|
||||
from backend.blocks.smartlead.models import (
|
||||
AddLeadsToCampaignResponse,
|
||||
LeadInput,
|
||||
)
|
||||
|
||||
block = AddLeadToCampaignBlock()
|
||||
|
||||
fake_leads = [
|
||||
LeadInput(first_name="Alice", last_name="A", email="alice@example.com"),
|
||||
LeadInput(first_name="Bob", last_name="B", email="bob@example.com"),
|
||||
]
|
||||
fake_response = AddLeadsToCampaignResponse(
|
||||
ok=True,
|
||||
upload_count=2,
|
||||
total_leads=2,
|
||||
block_count=0,
|
||||
duplicate_count=0,
|
||||
invalid_email_count=0,
|
||||
invalid_emails=[],
|
||||
already_added_to_campaign=0,
|
||||
unsubscribed_leads=[],
|
||||
is_lead_limit_exhausted=False,
|
||||
lead_import_stopped_count=0,
|
||||
bounce_count=0,
|
||||
)
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
AddLeadToCampaignBlock,
|
||||
"add_leads_to_campaign",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_response,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = AddLeadToCampaignBlock.Input(
|
||||
campaign_id=123,
|
||||
lead_list=fake_leads,
|
||||
credentials=SL_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=SL_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 2.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SearchPeopleBlock — item count from people list length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchPeopleBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_people_count(self):
|
||||
"""provider_cost equals number of returned people, type == 'items'."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.models import Contact
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
|
||||
block = SearchPeopleBlock()
|
||||
fake_people = [Contact(id=str(i), first_name=f"Person{i}") for i in range(5)]
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchPeopleBlock,
|
||||
"search_people",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_people,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchPeopleBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(5.0)
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_people_list_tracks_zero(self):
|
||||
"""An empty people list results in provider_cost=0.0."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
|
||||
block = SearchPeopleBlock()
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchPeopleBlock,
|
||||
"search_people",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchPeopleBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
@@ -9,6 +9,7 @@ from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -116,3 +117,10 @@ class ExaCodeContextBlock(Block):
|
||||
yield "cost_dollars", context.cost_dollars
|
||||
yield "search_time", context.search_time
|
||||
yield "output_tokens", context.output_tokens
|
||||
|
||||
# Parse cost_dollars (API returns as string, e.g. "0.005")
|
||||
try:
|
||||
cost_usd = float(context.cost_dollars)
|
||||
self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
from exa_py import AsyncExa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -223,3 +224,6 @@ class ExaContentsBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -0,0 +1,575 @@
|
||||
"""Tests for cost tracking in Exa blocks.
|
||||
|
||||
Covers the cost_dollars → provider_cost → merge_stats path for both
|
||||
ExaContentsBlock and ExaCodeContextBlock.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
|
||||
class TestExaCodeContextCostTracking:
|
||||
"""ExaCodeContextBlock parses cost_dollars (string) and calls merge_stats."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_cost_string_is_parsed_and_merged(self):
|
||||
"""A numeric cost string like '0.005' is merged as provider_cost."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-1",
|
||||
"query": "test query",
|
||||
"response": "some code",
|
||||
"resultsCount": 3,
|
||||
"costDollars": "0.005",
|
||||
"searchTime": 1.2,
|
||||
"outputTokens": 100,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
outputs = []
|
||||
async for key, value in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
outputs.append((key, value))
|
||||
|
||||
assert any(k == "cost_dollars" for k, _ in outputs)
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.005)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cost_string_does_not_raise(self):
|
||||
"""A non-numeric cost_dollars value is swallowed silently."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-2",
|
||||
"query": "test",
|
||||
"response": "code",
|
||||
"resultsCount": 0,
|
||||
"costDollars": "N/A",
|
||||
"searchTime": 0.5,
|
||||
"outputTokens": 0,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
outputs = []
|
||||
async for key, value in block.run(
|
||||
block.Input(query="test", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
outputs.append((key, value))
|
||||
|
||||
# No merge_stats call because float() raised ValueError
|
||||
assert len(merged) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cost_string_is_merged(self):
|
||||
"""'0.0' is a valid cost — should still be tracked."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-3",
|
||||
"query": "free query",
|
||||
"response": "result",
|
||||
"resultsCount": 1,
|
||||
"costDollars": "0.0",
|
||||
"searchTime": 0.1,
|
||||
"outputTokens": 10,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="free query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.0)
|
||||
|
||||
|
||||
class TestExaContentsCostTracking:
|
||||
"""ExaContentsBlock merges cost_dollars.total as provider_cost."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_dollars_total_is_merged(self):
|
||||
"""When the SDK response includes cost_dollars, its total is merged."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
|
||||
block = ExaContentsBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.statuses = None
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.012)
|
||||
|
||||
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.012)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_dollars_skips_merge(self):
|
||||
"""When cost_dollars is absent, merge_stats is not called."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
|
||||
block = ExaContentsBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.statuses = None
|
||||
mock_sdk_response.cost_dollars = None
|
||||
|
||||
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cost_dollars_is_merged(self):
|
||||
"""A total of 0.0 (free tier) should still be merged."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
|
||||
block = ExaContentsBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.statuses = None
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.0)
|
||||
|
||||
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.0)
|
||||
|
||||
|
||||
class TestExaSearchCostTracking:
|
||||
"""ExaSearchBlock merges cost_dollars.total as provider_cost."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_dollars_total_is_merged(self):
|
||||
"""When the SDK response includes cost_dollars, its total is merged."""
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
from backend.blocks.exa.search import ExaSearchBlock
|
||||
|
||||
block = ExaSearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.resolved_search_type = None
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.008)
|
||||
|
||||
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.008)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_dollars_skips_merge(self):
|
||||
"""When cost_dollars is absent, merge_stats is not called."""
|
||||
from backend.blocks.exa.search import ExaSearchBlock
|
||||
|
||||
block = ExaSearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.resolved_search_type = None
|
||||
mock_sdk_response.cost_dollars = None
|
||||
|
||||
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 0
|
||||
|
||||
|
||||
class TestExaSimilarCostTracking:
|
||||
"""ExaFindSimilarBlock merges cost_dollars.total as provider_cost."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_dollars_total_is_merged(self):
|
||||
"""When the SDK response includes cost_dollars, its total is merged."""
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
from backend.blocks.exa.similar import ExaFindSimilarBlock
|
||||
|
||||
block = ExaFindSimilarBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.request_id = "req-1"
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.015)
|
||||
|
||||
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.015)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_dollars_skips_merge(self):
|
||||
"""When cost_dollars is absent, merge_stats is not called."""
|
||||
from backend.blocks.exa.similar import ExaFindSimilarBlock
|
||||
|
||||
block = ExaFindSimilarBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.request_id = "req-2"
|
||||
mock_sdk_response.cost_dollars = None
|
||||
|
||||
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaCreateResearchBlock — cost_dollars from completed poll response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
COMPLETED_RESEARCH_RESPONSE = {
|
||||
"researchId": "test-research-id",
|
||||
"status": "completed",
|
||||
"model": "exa-research",
|
||||
"instructions": "test instructions",
|
||||
"createdAt": 1700000000000,
|
||||
"finishedAt": 1700000060000,
|
||||
"costDollars": {
|
||||
"total": 0.05,
|
||||
"numSearches": 3,
|
||||
"numPages": 10,
|
||||
"reasoningTokens": 500,
|
||||
},
|
||||
"output": {"content": "Research findings...", "parsed": None},
|
||||
}
|
||||
|
||||
PENDING_RESEARCH_RESPONSE = {
|
||||
"researchId": "test-research-id",
|
||||
"status": "pending",
|
||||
"model": "exa-research",
|
||||
"instructions": "test instructions",
|
||||
"createdAt": 1700000000000,
|
||||
}
|
||||
|
||||
|
||||
class TestExaCreateResearchBlockCostTracking:
|
||||
"""ExaCreateResearchBlock merges cost from completed poll response."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_merged_when_research_completes(self):
|
||||
"""merge_stats called with provider_cost=total when poll returns completed."""
|
||||
from backend.blocks.exa.research import ExaCreateResearchBlock
|
||||
|
||||
block = ExaCreateResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
|
||||
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.post = AsyncMock(return_value=create_resp)
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
instructions="test instructions",
|
||||
wait_for_completion=True,
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.05)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_when_no_cost_dollars(self):
|
||||
"""When completed response has no costDollars, merge_stats is not called."""
|
||||
from backend.blocks.exa.research import ExaCreateResearchBlock
|
||||
|
||||
block = ExaCreateResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
|
||||
create_resp = MagicMock()
|
||||
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = no_cost_response
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.post = AsyncMock(return_value=create_resp)
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
instructions="test instructions",
|
||||
wait_for_completion=True,
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merged == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaGetResearchBlock — cost_dollars from single GET response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaGetResearchBlockCostTracking:
|
||||
"""ExaGetResearchBlock merges cost when the fetched research has cost_dollars."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_merged_from_completed_research(self):
|
||||
"""merge_stats called with provider_cost=total when research has costDollars."""
|
||||
from backend.blocks.exa.research import ExaGetResearchBlock
|
||||
|
||||
block = ExaGetResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
get_resp = MagicMock()
|
||||
get_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=get_resp)
|
||||
|
||||
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.05)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_when_no_cost_dollars(self):
|
||||
"""When research has no costDollars, merge_stats is not called."""
|
||||
from backend.blocks.exa.research import ExaGetResearchBlock
|
||||
|
||||
block = ExaGetResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
|
||||
get_resp = MagicMock()
|
||||
get_resp.json.return_value = no_cost_response
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=get_resp)
|
||||
|
||||
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merged == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaWaitForResearchBlock — cost_dollars from polling response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaWaitForResearchBlockCostTracking:
|
||||
"""ExaWaitForResearchBlock merges cost when the polled research has cost_dollars."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_merged_when_research_completes(self):
|
||||
"""merge_stats called with provider_cost=total once polling returns completed."""
|
||||
from backend.blocks.exa.research import ExaWaitForResearchBlock
|
||||
|
||||
block = ExaWaitForResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.05)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_when_no_cost_dollars(self):
|
||||
"""When completed research has no costDollars, merge_stats is not called."""
|
||||
from backend.blocks.exa.research import ExaWaitForResearchBlock
|
||||
|
||||
block = ExaWaitForResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = no_cost_response
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merged == []
|
||||
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -232,6 +233,11 @@ class ExaCreateResearchBlock(Block):
|
||||
|
||||
if research.cost_dollars:
|
||||
yield "cost_total", research.cost_dollars.total
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=research.cost_dollars.total
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
await asyncio.sleep(check_interval)
|
||||
@@ -346,6 +352,9 @@ class ExaGetResearchBlock(Block):
|
||||
yield "cost_searches", research.cost_dollars.num_searches
|
||||
yield "cost_pages", research.cost_dollars.num_pages
|
||||
yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=research.cost_dollars.total)
|
||||
)
|
||||
|
||||
yield "error_message", research.error
|
||||
|
||||
@@ -432,6 +441,9 @@ class ExaWaitForResearchBlock(Block):
|
||||
|
||||
if research.cost_dollars:
|
||||
yield "cost_total", research.cost_dollars.total
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=research.cost_dollars.total)
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -206,3 +207,6 @@ class ExaSearchBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -167,3 +168,6 @@ class ExaFindSimilarBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
from abc import ABC
|
||||
from email import encoders
|
||||
from email.mime.base import MIMEBase
|
||||
@@ -8,7 +9,7 @@ from email.mime.text import MIMEText
|
||||
from email.policy import SMTP
|
||||
from email.utils import getaddresses, parseaddr
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
@@ -42,8 +43,52 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
|
||||
|
||||
|
||||
def serialize_email_recipients(recipients: list[str]) -> str:
|
||||
"""Serialize recipients list to comma-separated string."""
|
||||
return ", ".join(recipients)
|
||||
"""Serialize recipients list to comma-separated string.
|
||||
|
||||
Strips leading/trailing whitespace from each address to keep MIME
|
||||
headers clean (mirrors the strip done in ``validate_email_recipients``).
|
||||
"""
|
||||
return ", ".join(addr.strip() for addr in recipients)
|
||||
|
||||
|
||||
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
|
||||
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
|
||||
|
||||
|
||||
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
|
||||
"""Validate that all recipients are plausible email addresses.
|
||||
|
||||
Raises ``ValueError`` with a user-friendly message listing every
|
||||
invalid entry so the caller (or LLM) can correct them in one pass.
|
||||
"""
|
||||
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
|
||||
if invalid:
|
||||
formatted = ", ".join(f"'{a}'" for a in invalid)
|
||||
raise ValueError(
|
||||
f"Invalid email address(es) in '{field_name}': {formatted}. "
|
||||
f"Each entry must be a valid email address (e.g. user@example.com)."
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasRecipients(Protocol):
|
||||
to: list[str]
|
||||
cc: list[str]
|
||||
bcc: list[str]
|
||||
|
||||
|
||||
def validate_all_recipients(input_data: HasRecipients) -> None:
|
||||
"""Validate to/cc/bcc recipient fields on an input namespace.
|
||||
|
||||
Calls ``validate_email_recipients`` for ``to`` (required) and
|
||||
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
|
||||
first field that contains an invalid address.
|
||||
"""
|
||||
validate_email_recipients(input_data.to, "to")
|
||||
if input_data.cc:
|
||||
validate_email_recipients(input_data.cc, "cc")
|
||||
if input_data.bcc:
|
||||
validate_email_recipients(input_data.bcc, "bcc")
|
||||
|
||||
|
||||
def _make_mime_text(
|
||||
@@ -100,14 +145,16 @@ async def create_mime_message(
|
||||
) -> str:
|
||||
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
||||
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
message = MIMEMultipart()
|
||||
message["to"] = serialize_email_recipients(input_data.to)
|
||||
message["subject"] = input_data.subject
|
||||
|
||||
if input_data.cc:
|
||||
message["cc"] = ", ".join(input_data.cc)
|
||||
message["cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
message["bcc"] = ", ".join(input_data.bcc)
|
||||
message["bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
|
||||
# Use the new helper function with content_type if available
|
||||
content_type = getattr(input_data, "content_type", None)
|
||||
@@ -1167,13 +1214,15 @@ async def _build_reply_message(
|
||||
references.append(headers["message-id"])
|
||||
|
||||
# Create MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
@@ -1685,13 +1734,16 @@ To: {original_to}
|
||||
else:
|
||||
body = f"{forward_header}\n\n{original_body}"
|
||||
|
||||
# Validate all recipient lists before building the MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
# Create MIME message
|
||||
msg = MIMEMultipart()
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
|
||||
# Add body with proper content type
|
||||
|
||||
@@ -14,6 +14,7 @@ from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -117,6 +118,11 @@ class GoogleMapsSearchBlock(Block):
|
||||
input_data.radius,
|
||||
input_data.max_results,
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(places)), provider_cost_type="items"
|
||||
)
|
||||
)
|
||||
for place in places:
|
||||
yield "place", place
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ import copy
|
||||
from datetime import date, time
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -467,7 +469,8 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
|
||||
class AgentDropdownInputBlock(AgentInputBlock):
|
||||
"""
|
||||
A specialized text input block that relies on placeholder_values to present a dropdown.
|
||||
A specialized text input block that presents a dropdown selector
|
||||
restricted to a fixed set of values.
|
||||
"""
|
||||
|
||||
class Input(AgentInputBlock.Input):
|
||||
@@ -477,16 +480,23 @@ class AgentDropdownInputBlock(AgentInputBlock):
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="Possible values for the dropdown.",
|
||||
# Use Field() directly (not SchemaField) to pass validation_alias,
|
||||
# which handles backward compat for legacy "placeholder_values" across
|
||||
# all construction paths (model_construct, __init__, model_validate).
|
||||
options: list = Field(
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
title="Dropdown Options",
|
||||
description=(
|
||||
"If provided, renders the input as a dropdown selector "
|
||||
"restricted to these values. Leave empty for free-text input."
|
||||
),
|
||||
validation_alias=AliasChoices("options", "placeholder_values"),
|
||||
json_schema_extra={"advanced": False, "secret": False},
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = super().generate_schema()
|
||||
if possible_values := self.placeholder_values:
|
||||
if possible_values := self.options:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
|
||||
@@ -504,13 +514,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
|
||||
{
|
||||
"value": "Option A",
|
||||
"name": "dropdown_1",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"options": ["Option A", "Option B", "Option C"],
|
||||
"description": "Dropdown example 1",
|
||||
},
|
||||
{
|
||||
"value": "Option C",
|
||||
"name": "dropdown_2",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"options": ["Option A", "Option B", "Option C"],
|
||||
"description": "Dropdown example 2",
|
||||
},
|
||||
],
|
||||
|
||||
@@ -10,7 +10,7 @@ from backend.blocks.jina._auth import (
|
||||
JinaCredentialsField,
|
||||
JinaCredentialsInput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util.request import Requests
|
||||
|
||||
|
||||
@@ -45,5 +45,13 @@ class JinaEmbeddingBlock(Block):
|
||||
}
|
||||
data = {"input": input_data.texts, "model": input_data.model}
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
embeddings = [e["embedding"] for e in response.json()["data"]]
|
||||
resp_json = response.json()
|
||||
embeddings = [e["embedding"] for e in resp_json["data"]]
|
||||
usage = resp_json.get("usage", {})
|
||||
if usage.get("total_tokens"):
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=usage.get("total_tokens", 0),
|
||||
)
|
||||
)
|
||||
yield "embeddings", embeddings
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# This file contains a lot of prompt block strings that would trigger "line too long"
|
||||
# flake8: noqa: E501
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import secrets
|
||||
from abc import ABC
|
||||
@@ -13,6 +14,7 @@ import ollama
|
||||
import openai
|
||||
from anthropic.types import ToolParam
|
||||
from groq import AsyncGroq
|
||||
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -104,7 +106,6 @@ class LlmModelMeta(EnumMeta):
|
||||
|
||||
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> "LlmModel | None":
|
||||
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
|
||||
@@ -201,10 +202,25 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
GROK_4_20 = "x-ai/grok-4.20"
|
||||
GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent"
|
||||
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
|
||||
KIMI_K2 = "moonshotai/kimi-k2"
|
||||
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
|
||||
QWEN3_CODER = "qwen/qwen3-coder"
|
||||
# Z.ai (Zhipu) models
|
||||
ZAI_GLM_4_32B = "z-ai/glm-4-32b"
|
||||
ZAI_GLM_4_5 = "z-ai/glm-4.5"
|
||||
ZAI_GLM_4_5_AIR = "z-ai/glm-4.5-air"
|
||||
ZAI_GLM_4_5_AIR_FREE = "z-ai/glm-4.5-air:free"
|
||||
ZAI_GLM_4_5V = "z-ai/glm-4.5v"
|
||||
ZAI_GLM_4_6 = "z-ai/glm-4.6"
|
||||
ZAI_GLM_4_6V = "z-ai/glm-4.6v"
|
||||
ZAI_GLM_4_7 = "z-ai/glm-4.7"
|
||||
ZAI_GLM_4_7_FLASH = "z-ai/glm-4.7-flash"
|
||||
ZAI_GLM_5 = "z-ai/glm-5"
|
||||
ZAI_GLM_5_TURBO = "z-ai/glm-5-turbo"
|
||||
ZAI_GLM_5V_TURBO = "z-ai/glm-5v-turbo"
|
||||
# Llama API models
|
||||
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
@@ -612,6 +628,18 @@ MODEL_METADATA = {
|
||||
LlmModel.GROK_4_1_FAST: ModelMetadata(
|
||||
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.GROK_4_20: ModelMetadata(
|
||||
"open_router", 2000000, 100000, "Grok 4.20", "OpenRouter", "xAI", 3
|
||||
),
|
||||
LlmModel.GROK_4_20_MULTI_AGENT: ModelMetadata(
|
||||
"open_router",
|
||||
2000000,
|
||||
100000,
|
||||
"Grok 4.20 Multi-Agent",
|
||||
"OpenRouter",
|
||||
"xAI",
|
||||
3,
|
||||
),
|
||||
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
|
||||
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
|
||||
),
|
||||
@@ -630,6 +658,43 @@ MODEL_METADATA = {
|
||||
LlmModel.QWEN3_CODER: ModelMetadata(
|
||||
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
|
||||
),
|
||||
# https://openrouter.ai/models?q=z-ai
|
||||
LlmModel.ZAI_GLM_4_32B: ModelMetadata(
|
||||
"open_router", 128000, 128000, "GLM 4 32B", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5: ModelMetadata(
|
||||
"open_router", 131072, 98304, "GLM 4.5", "OpenRouter", "Z.ai", 2
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5_AIR: ModelMetadata(
|
||||
"open_router", 131072, 98304, "GLM 4.5 Air", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5_AIR_FREE: ModelMetadata(
|
||||
"open_router", 131072, 96000, "GLM 4.5 Air (Free)", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5V: ModelMetadata(
|
||||
"open_router", 65536, 16384, "GLM 4.5V", "OpenRouter", "Z.ai", 2
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_6: ModelMetadata(
|
||||
"open_router", 204800, 204800, "GLM 4.6", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_6V: ModelMetadata(
|
||||
"open_router", 131072, 131072, "GLM 4.6V", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_7: ModelMetadata(
|
||||
"open_router", 202752, 65535, "GLM 4.7", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_7_FLASH: ModelMetadata(
|
||||
"open_router", 202752, 202752, "GLM 4.7 Flash", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_5: ModelMetadata(
|
||||
"open_router", 80000, 80000, "GLM 5", "OpenRouter", "Z.ai", 2
|
||||
),
|
||||
LlmModel.ZAI_GLM_5_TURBO: ModelMetadata(
|
||||
"open_router", 202752, 131072, "GLM 5 Turbo", "OpenRouter", "Z.ai", 3
|
||||
),
|
||||
LlmModel.ZAI_GLM_5V_TURBO: ModelMetadata(
|
||||
"open_router", 202752, 131072, "GLM 5V Turbo", "OpenRouter", "Z.ai", 3
|
||||
),
|
||||
# Llama API models
|
||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
|
||||
"llama_api",
|
||||
@@ -686,17 +751,20 @@ class LLMResponse(BaseModel):
|
||||
tool_calls: Optional[List[ToolContentBlock]] | None
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
cache_read_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
reasoning: Optional[str] = None
|
||||
provider_cost: float | None = None
|
||||
|
||||
|
||||
def convert_openai_tool_fmt_to_anthropic(
|
||||
openai_tools: list[dict] | None = None,
|
||||
) -> Iterable[ToolParam] | anthropic.Omit:
|
||||
) -> Iterable[ToolParam] | anthropic.NotGiven:
|
||||
"""
|
||||
Convert OpenAI tool format to Anthropic tool format.
|
||||
"""
|
||||
if not openai_tools or len(openai_tools) == 0:
|
||||
return anthropic.omit
|
||||
return anthropic.NOT_GIVEN
|
||||
|
||||
anthropic_tools = []
|
||||
for tool in openai_tools:
|
||||
@@ -721,9 +789,41 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
return anthropic_tools
|
||||
|
||||
|
||||
def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None:
|
||||
"""Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response.
|
||||
|
||||
OpenRouter returns the per-request USD cost in a response header. The
|
||||
OpenAI SDK exposes the raw httpx response via an undocumented `_response`
|
||||
attribute. We use try/except AttributeError so that if the SDK ever drops
|
||||
or renames that attribute, the warning is visible in logs rather than
|
||||
silently degrading to no cost tracking.
|
||||
"""
|
||||
try:
|
||||
raw_resp = response._response # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"OpenAI SDK response missing _response attribute"
|
||||
" — OpenRouter cost tracking unavailable"
|
||||
)
|
||||
return None
|
||||
try:
|
||||
cost_header = raw_resp.headers.get("x-total-cost")
|
||||
if not cost_header:
|
||||
return None
|
||||
cost = float(cost_header)
|
||||
if not math.isfinite(cost) or cost < 0:
|
||||
return None
|
||||
return cost
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def extract_openai_reasoning(response) -> str | None:
|
||||
"""Extract reasoning from OpenAI-compatible response if available."""
|
||||
"""Note: This will likely not working since the reasoning is not present in another Response API"""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_reasoning")
|
||||
return None
|
||||
reasoning = None
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
|
||||
@@ -739,6 +839,9 @@ def extract_openai_reasoning(response) -> str | None:
|
||||
|
||||
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
"""Extract tool calls from OpenAI-compatible response."""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
|
||||
return None
|
||||
if response.choices[0].message.tool_calls:
|
||||
return [
|
||||
ToolContentBlock(
|
||||
@@ -797,6 +900,21 @@ async def llm_call(
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
|
||||
# Transparent OpenRouter routing for Anthropic models: when an OpenRouter API key
|
||||
# is configured, route direct-Anthropic models through OpenRouter instead. This
|
||||
# gives us the x-total-cost header for free, so provider_cost is always populated
|
||||
# without manual token-rate arithmetic.
|
||||
or_key = settings.secrets.open_router_api_key
|
||||
or_model_id: str | None = None
|
||||
if provider == "anthropic" and or_key:
|
||||
provider = "open_router"
|
||||
credentials = APIKeyCredentials(
|
||||
provider=ProviderName.OPEN_ROUTER,
|
||||
title="OpenRouter (auto)",
|
||||
api_key=SecretStr(or_key),
|
||||
)
|
||||
or_model_id = f"anthropic/{llm_model.value}"
|
||||
|
||||
if compress_prompt_to_fit:
|
||||
result = await compress_context(
|
||||
messages=prompt,
|
||||
@@ -882,8 +1000,12 @@ async def llm_call(
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
|
||||
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
|
||||
# Cache tool definitions alongside the system prompt.
|
||||
# Placing cache_control on the last tool caches all tool schemas as a
|
||||
# single prefix — reads cost 10% of normal input tokens.
|
||||
if isinstance(an_tools, list) and an_tools:
|
||||
an_tools[-1] = {**an_tools[-1], "cache_control": {"type": "ephemeral"}}
|
||||
|
||||
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
sysprompt = " ".join(system_messages)
|
||||
@@ -906,14 +1028,34 @@ async def llm_call(
|
||||
client = anthropic.AsyncAnthropic(
|
||||
api_key=credentials.api_key.get_secret_value()
|
||||
)
|
||||
resp = await client.messages.create(
|
||||
# create_kwargs is built as a plain dict so we can conditionally add
|
||||
# the `system` field only when the prompt is non-empty. Anthropic's
|
||||
# API rejects empty text blocks (returns HTTP 400), so omitting the
|
||||
# field is the correct behaviour for whitespace-only prompts.
|
||||
create_kwargs: dict[str, Any] = dict(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
# `an_tools` may be anthropic.NOT_GIVEN when no tools were
|
||||
# configured. The SDK treats NOT_GIVEN as a sentinel meaning "omit
|
||||
# this field from the serialized request", so passing it here is
|
||||
# equivalent to not including the key at all — no `tools` field is
|
||||
# sent to the API in that case.
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
if sysprompt.strip():
|
||||
# Wrap the system prompt in a single cacheable text block.
|
||||
# The guard intentionally omits `system` for whitespace-only
|
||||
# prompts — Anthropic rejects empty text blocks with HTTP 400.
|
||||
create_kwargs["system"] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": sysprompt,
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
]
|
||||
resp = await client.messages.create(**create_kwargs)
|
||||
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
@@ -958,6 +1100,11 @@ async def llm_call(
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
cache_read_tokens=getattr(resp.usage, "cache_read_input_tokens", None) or 0,
|
||||
cache_creation_tokens=getattr(
|
||||
resp.usage, "cache_creation_input_tokens", None
|
||||
)
|
||||
or 0,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "groq":
|
||||
@@ -972,6 +1119,8 @@ async def llm_call(
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not response.choices:
|
||||
raise ValueError("Groq returned empty choices in response")
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
@@ -1024,19 +1173,15 @@ async def llm_call(
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=llm_model.value,
|
||||
model=or_model_id or llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"OpenRouter error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
raise ValueError(f"OpenRouter returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1049,6 +1194,7 @@ async def llm_call(
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
provider_cost=extract_openrouter_cost(response),
|
||||
)
|
||||
elif provider == "llama_api":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
@@ -1073,12 +1219,8 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"Llama API error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from Llama API.")
|
||||
raise ValueError(f"Llama API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1108,6 +1250,8 @@ async def llm_call(
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not completion.choices:
|
||||
raise ValueError("AI/ML API returned empty choices in response")
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=completion.choices[0].message,
|
||||
@@ -1144,6 +1288,9 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError(f"v0 API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
@@ -1355,6 +1502,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
error_feedback_message = ""
|
||||
llm_model = input_data.model
|
||||
total_provider_cost: float | None = None
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
@@ -1372,12 +1520,19 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
response_text = llm_response.response
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=llm_response.prompt_tokens,
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
# Accumulate token counts and provider_cost for every attempt
|
||||
# (each call costs tokens and USD, regardless of validation outcome).
|
||||
token_stats = NodeExecutionStats(
|
||||
input_token_count=llm_response.prompt_tokens,
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
cache_read_token_count=llm_response.cache_read_tokens,
|
||||
cache_creation_token_count=llm_response.cache_creation_tokens,
|
||||
)
|
||||
self.merge_stats(token_stats)
|
||||
if llm_response.provider_cost is not None:
|
||||
total_provider_cost = (
|
||||
total_provider_cost or 0.0
|
||||
) + llm_response.provider_cost
|
||||
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
@@ -1446,6 +1601,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
llm_retry_count=retry_count,
|
||||
provider_cost=total_provider_cost,
|
||||
)
|
||||
)
|
||||
yield "response", response_obj
|
||||
@@ -1466,6 +1622,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
llm_retry_count=retry_count,
|
||||
provider_cost=total_provider_cost,
|
||||
)
|
||||
)
|
||||
yield "response", {"response": response_text}
|
||||
@@ -1497,6 +1654,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
error_feedback_message = f"Error calling LLM: {e}"
|
||||
|
||||
# All retries exhausted or user-error break: persist accumulated cost so
|
||||
# the executor can still charge/report the spend even on failure.
|
||||
if total_provider_cost is not None:
|
||||
self.merge_stats(NodeExecutionStats(provider_cost=total_provider_cost))
|
||||
raise RuntimeError(error_feedback_message)
|
||||
|
||||
def response_format_instructions(
|
||||
@@ -2011,6 +2172,19 @@ class AIConversationBlock(AIBlockBase):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
has_messages = any(
|
||||
isinstance(m, dict)
|
||||
and isinstance(m.get("content"), str)
|
||||
and bool(m["content"].strip())
|
||||
for m in (input_data.messages or [])
|
||||
)
|
||||
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
|
||||
if not has_messages and not has_prompt:
|
||||
raise ValueError(
|
||||
"Cannot call LLM with no messages and no prompt. "
|
||||
"Provide at least one message or a non-empty prompt."
|
||||
)
|
||||
|
||||
response = await self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt=input_data.prompt,
|
||||
|
||||
@@ -89,6 +89,12 @@ class MCPToolBlock(Block):
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
tool_description: str = SchemaField(
|
||||
description="Description of the selected MCP tool. "
|
||||
"Populated automatically when a tool is selected.",
|
||||
default="",
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
tool_arguments: dict[str, Any] = SchemaField(
|
||||
description="Arguments to pass to the selected MCP tool. "
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,7 +23,7 @@ from backend.blocks.smartlead.models import (
|
||||
SaveSequencesResponse,
|
||||
Sequence,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class CreateCampaignBlock(Block):
|
||||
@@ -226,6 +226,12 @@ class AddLeadToCampaignBlock(Block):
|
||||
response = await self.add_leads_to_campaign(
|
||||
input_data.campaign_id, input_data.lead_list, credentials
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(input_data.lead_list)),
|
||||
provider_cost_type="items",
|
||||
)
|
||||
)
|
||||
|
||||
yield "campaign_id", input_data.campaign_id
|
||||
yield "upload_count", response.upload_count
|
||||
|
||||
323
autogpt_platform/backend/backend/blocks/sql_query_block.py
Normal file
323
autogpt_platform/backend/backend/blocks/sql_query_block.py
Normal file
@@ -0,0 +1,323 @@
|
||||
import asyncio
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError, ProgrammingError
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.blocks.sql_query_helpers import (
|
||||
_DATABASE_TYPE_DEFAULT_PORT,
|
||||
_DATABASE_TYPE_TO_DRIVER,
|
||||
DatabaseType,
|
||||
_execute_query,
|
||||
_sanitize_error,
|
||||
_validate_query_is_read_only,
|
||||
_validate_single_statement,
|
||||
)
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import resolve_and_check_blocked
|
||||
|
||||
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="database",
|
||||
username=SecretStr("test_user"),
|
||||
password=SecretStr("test_pass"),
|
||||
title="Mock Database credentials",
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
DatabaseCredentials = UserPasswordCredentials
|
||||
DatabaseCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.DATABASE],
|
||||
Literal["user_password"],
|
||||
]
|
||||
|
||||
|
||||
def DatabaseCredentialsField() -> DatabaseCredentialsInput:
|
||||
return CredentialsField(
|
||||
description="Database username and password",
|
||||
)
|
||||
|
||||
|
||||
class SQLQueryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
database_type: DatabaseType = SchemaField(
|
||||
default=DatabaseType.POSTGRES,
|
||||
description="Database engine",
|
||||
advanced=False,
|
||||
)
|
||||
host: SecretStr = SchemaField(
|
||||
description=(
|
||||
"Database hostname or IP address. "
|
||||
"Treated as a secret to avoid leaking infrastructure details. "
|
||||
"Private/internal IPs are blocked (SSRF protection)."
|
||||
),
|
||||
placeholder="db.example.com",
|
||||
secret=True,
|
||||
)
|
||||
port: int | None = SchemaField(
|
||||
default=None,
|
||||
description=(
|
||||
"Database port (leave empty for default: "
|
||||
"PostgreSQL: 5432, MySQL: 3306, MSSQL: 1433)"
|
||||
),
|
||||
ge=1,
|
||||
le=65535,
|
||||
)
|
||||
database: str = SchemaField(
|
||||
description="Name of the database to connect to",
|
||||
placeholder="my_database",
|
||||
)
|
||||
query: str = SchemaField(
|
||||
description="SQL query to execute",
|
||||
placeholder="SELECT * FROM analytics.daily_active_users LIMIT 10",
|
||||
)
|
||||
read_only: bool = SchemaField(
|
||||
default=True,
|
||||
description=(
|
||||
"When enabled (default), only SELECT queries are allowed "
|
||||
"and the database session is set to read-only mode. "
|
||||
"Disable to allow write operations (INSERT, UPDATE, DELETE, etc.)."
|
||||
),
|
||||
)
|
||||
timeout: int = SchemaField(
|
||||
default=30,
|
||||
description="Query timeout in seconds (max 120)",
|
||||
ge=1,
|
||||
le=120,
|
||||
)
|
||||
max_rows: int = SchemaField(
|
||||
default=1000,
|
||||
description="Maximum number of rows to return (max 10000)",
|
||||
ge=1,
|
||||
le=10000,
|
||||
)
|
||||
credentials: DatabaseCredentialsInput = DatabaseCredentialsField()
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
results: list[dict[str, Any]] = SchemaField(
|
||||
description="Query results as a list of row dictionaries"
|
||||
)
|
||||
columns: list[str] = SchemaField(
|
||||
description="Column names from the query result"
|
||||
)
|
||||
row_count: int = SchemaField(description="Number of rows returned")
|
||||
truncated: bool = SchemaField(
|
||||
description=(
|
||||
"True when the result set was capped by max_rows, "
|
||||
"indicating additional rows exist in the database"
|
||||
)
|
||||
)
|
||||
affected_rows: int = SchemaField(
|
||||
description="Number of rows affected by a write query (INSERT/UPDATE/DELETE)"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the query failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4dc35c0f-4fd8-465e-9616-5a216f1ba2bc",
|
||||
description=(
|
||||
"Execute a SQL query. Read-only by default for safety "
|
||||
"-- disable to allow write operations. "
|
||||
"Supports PostgreSQL, MySQL, and MSSQL via SQLAlchemy."
|
||||
),
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=SQLQueryBlock.Input,
|
||||
output_schema=SQLQueryBlock.Output,
|
||||
test_input={
|
||||
"query": "SELECT 1 AS test_col",
|
||||
"database_type": DatabaseType.POSTGRES,
|
||||
"host": "localhost",
|
||||
"database": "test_db",
|
||||
"timeout": 30,
|
||||
"max_rows": 1000,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("results", [{"test_col": 1}]),
|
||||
("columns", ["test_col"]),
|
||||
("row_count", 1),
|
||||
("truncated", False),
|
||||
],
|
||||
test_mock={
|
||||
"execute_query": lambda *_args, **_kwargs: (
|
||||
[{"test_col": 1}],
|
||||
["test_col"],
|
||||
-1,
|
||||
False,
|
||||
),
|
||||
"check_host_allowed": lambda *_args, **_kwargs: ["127.0.0.1"],
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def check_host_allowed(host: str) -> list[str]:
|
||||
"""Validate that the given host is not a private/blocked address.
|
||||
|
||||
Returns the list of resolved IP addresses so the caller can pin the
|
||||
connection to the validated IP (preventing DNS rebinding / TOCTOU).
|
||||
Raises ValueError or OSError if the host is blocked.
|
||||
Extracted as a method so it can be mocked during block tests.
|
||||
"""
|
||||
return await resolve_and_check_blocked(host)
|
||||
|
||||
@staticmethod
|
||||
def execute_query(
|
||||
connection_url: URL | str,
|
||||
query: str,
|
||||
timeout: int,
|
||||
max_rows: int,
|
||||
read_only: bool = True,
|
||||
database_type: DatabaseType = DatabaseType.POSTGRES,
|
||||
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
|
||||
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
|
||||
|
||||
Delegates to ``_execute_query`` in ``sql_query_helpers``.
|
||||
Extracted as a method so it can be mocked during block tests.
|
||||
"""
|
||||
return _execute_query(
|
||||
connection_url=connection_url,
|
||||
query=query,
|
||||
timeout=timeout,
|
||||
max_rows=max_rows,
|
||||
read_only=read_only,
|
||||
database_type=database_type,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: DatabaseCredentials,
|
||||
**_kwargs: Any,
|
||||
) -> BlockOutput:
|
||||
# Validate query structure and read-only constraints.
|
||||
error = self._validate_query(input_data)
|
||||
if error:
|
||||
yield "error", error
|
||||
return
|
||||
|
||||
# Validate host and resolve for SSRF protection.
|
||||
host, pinned_host, error = await self._resolve_host(input_data)
|
||||
if error:
|
||||
yield "error", error
|
||||
return
|
||||
|
||||
# Build connection URL and execute.
|
||||
port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
|
||||
username = credentials.username.get_secret_value()
|
||||
connection_url = URL.create(
|
||||
drivername=_DATABASE_TYPE_TO_DRIVER[input_data.database_type],
|
||||
username=username,
|
||||
password=credentials.password.get_secret_value(),
|
||||
host=pinned_host,
|
||||
port=port,
|
||||
database=input_data.database,
|
||||
)
|
||||
conn_str = connection_url.render_as_string(hide_password=True)
|
||||
db_name = input_data.database
|
||||
|
||||
def _sanitize(err: Exception) -> str:
|
||||
return _sanitize_error(
|
||||
str(err).strip(),
|
||||
conn_str,
|
||||
host=pinned_host,
|
||||
original_host=host,
|
||||
username=username,
|
||||
port=port,
|
||||
database=db_name,
|
||||
)
|
||||
|
||||
try:
|
||||
results, columns, affected, truncated = await asyncio.to_thread(
|
||||
self.execute_query,
|
||||
connection_url=connection_url,
|
||||
query=input_data.query,
|
||||
timeout=input_data.timeout,
|
||||
max_rows=input_data.max_rows,
|
||||
read_only=input_data.read_only,
|
||||
database_type=input_data.database_type,
|
||||
)
|
||||
yield "results", results
|
||||
yield "columns", columns
|
||||
yield "row_count", len(results)
|
||||
yield "truncated", truncated
|
||||
if affected >= 0:
|
||||
yield "affected_rows", affected
|
||||
except OperationalError as e:
|
||||
yield (
|
||||
"error",
|
||||
self._classify_operational_error(
|
||||
_sanitize(e),
|
||||
input_data.timeout,
|
||||
),
|
||||
)
|
||||
except ProgrammingError as e:
|
||||
yield "error", f"SQL error: {_sanitize(e)}"
|
||||
except DBAPIError as e:
|
||||
yield "error", f"Database error: {_sanitize(e)}"
|
||||
except ModuleNotFoundError:
|
||||
yield (
|
||||
"error",
|
||||
(
|
||||
f"Database driver not available for "
|
||||
f"{input_data.database_type.value}. "
|
||||
f"Please contact the platform administrator."
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_query(input_data: "SQLQueryBlock.Input") -> str | None:
|
||||
"""Validate query structure and read-only constraints."""
|
||||
stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
|
||||
if stmt_error:
|
||||
return stmt_error
|
||||
assert parsed_stmt is not None
|
||||
if input_data.read_only:
|
||||
return _validate_query_is_read_only(parsed_stmt)
|
||||
return None
|
||||
|
||||
async def _resolve_host(
|
||||
self, input_data: "SQLQueryBlock.Input"
|
||||
) -> tuple[str, str, str | None]:
|
||||
"""Validate and resolve the database host. Returns (host, pinned_ip, error)."""
|
||||
host = input_data.host.get_secret_value().strip()
|
||||
if not host:
|
||||
return "", "", "Database host is required."
|
||||
if host.startswith("/"):
|
||||
return host, "", "Unix socket connections are not allowed."
|
||||
try:
|
||||
resolved_ips = await self.check_host_allowed(host)
|
||||
except (ValueError, OSError) as e:
|
||||
return host, "", f"Blocked host: {str(e).strip()}"
|
||||
return host, resolved_ips[0], None
|
||||
|
||||
@staticmethod
|
||||
def _classify_operational_error(sanitized_msg: str, timeout: int) -> str:
|
||||
"""Classify an already-sanitized OperationalError for user display."""
|
||||
lower = sanitized_msg.lower()
|
||||
if "timeout" in lower or "cancel" in lower:
|
||||
return f"Query timed out after {timeout}s."
|
||||
if "connect" in lower:
|
||||
return f"Failed to connect to database: {sanitized_msg}"
|
||||
return f"Database error: {sanitized_msg}"
|
||||
1851
autogpt_platform/backend/backend/blocks/sql_query_block_test.py
Normal file
1851
autogpt_platform/backend/backend/blocks/sql_query_block_test.py
Normal file
File diff suppressed because it is too large
Load Diff
430
autogpt_platform/backend/backend/blocks/sql_query_helpers.py
Normal file
430
autogpt_platform/backend/backend/blocks/sql_query_helpers.py
Normal file
@@ -0,0 +1,430 @@
|
||||
import re
|
||||
from datetime import date, datetime, time
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import sqlparse
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine.url import URL
|
||||
|
||||
|
||||
class DatabaseType(str, Enum):
|
||||
POSTGRES = "postgres"
|
||||
MYSQL = "mysql"
|
||||
MSSQL = "mssql"
|
||||
|
||||
|
||||
# Defense-in-depth: reject queries containing data-modifying keywords.
|
||||
# These are checked against parsed SQL tokens (not raw text) so column names
|
||||
# and string literals do not cause false positives.
|
||||
_DISALLOWED_KEYWORDS = {
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"TRUNCATE",
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
"COPY",
|
||||
"EXECUTE",
|
||||
"CALL",
|
||||
"SET",
|
||||
"RESET",
|
||||
"DISCARD",
|
||||
"NOTIFY",
|
||||
"DO",
|
||||
# MySQL file exfiltration: LOAD DATA LOCAL INFILE reads server/client files
|
||||
"LOAD",
|
||||
# MySQL REPLACE is INSERT-or-UPDATE; data modification
|
||||
"REPLACE",
|
||||
# ANSI MERGE (UPSERT) modifies data
|
||||
"MERGE",
|
||||
# MSSQL BULK INSERT loads external files into tables
|
||||
"BULK",
|
||||
# MSSQL EXEC / EXEC sp_name runs stored procedures (arbitrary code)
|
||||
"EXEC",
|
||||
}
|
||||
|
||||
# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
|
||||
_DATABASE_TYPE_TO_DRIVER = {
|
||||
DatabaseType.POSTGRES: "postgresql",
|
||||
DatabaseType.MYSQL: "mysql+pymysql",
|
||||
DatabaseType.MSSQL: "mssql+pymssql",
|
||||
}
|
||||
|
||||
# Connection timeout in seconds passed to the DBAPI driver (connect_timeout /
|
||||
# login_timeout). This bounds how long the driver waits to establish a TCP
|
||||
# connection to the database server. It is separate from the per-statement
|
||||
# timeout configured via SET commands inside _configure_session().
|
||||
_CONNECT_TIMEOUT_SECONDS = 10
|
||||
|
||||
# Default ports for each database type.
|
||||
_DATABASE_TYPE_DEFAULT_PORT = {
|
||||
DatabaseType.POSTGRES: 5432,
|
||||
DatabaseType.MYSQL: 3306,
|
||||
DatabaseType.MSSQL: 1433,
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_error(
|
||||
error_msg: str,
|
||||
connection_string: str,
|
||||
*,
|
||||
host: str = "",
|
||||
original_host: str = "",
|
||||
username: str = "",
|
||||
port: int = 0,
|
||||
database: str = "",
|
||||
) -> str:
|
||||
"""Remove connection string, credentials, and infrastructure details
|
||||
from error messages so they are safe to expose to the LLM.
|
||||
|
||||
Scrubs:
|
||||
- The full connection string
|
||||
- URL-embedded credentials (``://user:pass@``)
|
||||
- ``password=<value>`` key-value pairs
|
||||
- The database hostname / IP used for the connection
|
||||
- The original (pre-resolution) hostname provided by the user
|
||||
- Any IPv4 addresses that appear in the message
|
||||
- Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
|
||||
- The database username
|
||||
- The database port number
|
||||
- The database name
|
||||
"""
|
||||
sanitized = error_msg.replace(connection_string, "<connection_string>")
|
||||
sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
|
||||
sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)
|
||||
|
||||
# Replace the known host (may be an IP already) before the generic IP pass.
|
||||
# Also replace the original (pre-DNS-resolution) hostname if it differs.
|
||||
if original_host and original_host != host:
|
||||
sanitized = sanitized.replace(original_host, "<host>")
|
||||
if host:
|
||||
sanitized = sanitized.replace(host, "<host>")
|
||||
|
||||
# Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
|
||||
sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)
|
||||
|
||||
# Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
|
||||
sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)
|
||||
|
||||
# Replace the database username (handles double-quoted, single-quoted,
|
||||
# and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
|
||||
if username:
|
||||
sanitized = re.sub(
|
||||
r"""for user ["']?""" + re.escape(username) + r"""["']?""",
|
||||
"for user <user>",
|
||||
sanitized,
|
||||
)
|
||||
# Catch remaining bare occurrences in various quote styles:
|
||||
# - PostgreSQL: "FATAL: role "myuser" does not exist"
|
||||
# - MySQL: "Access denied for user 'myuser'@'host'"
|
||||
# - MSSQL: "Login failed for user 'myuser'"
|
||||
sanitized = sanitized.replace(f'"{username}"', "<user>")
|
||||
sanitized = sanitized.replace(f"'{username}'", "<user>")
|
||||
|
||||
# Replace the port number (handles "port 5432" and ":5432" formats)
|
||||
if port:
|
||||
port_str = re.escape(str(port))
|
||||
sanitized = re.sub(
|
||||
r"(?:port |:)" + port_str + r"(?![0-9])",
|
||||
lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
|
||||
sanitized,
|
||||
)
|
||||
|
||||
# Replace the database name to avoid leaking internal infrastructure names.
|
||||
# Use word-boundary regex to prevent mangling when the database name is a
|
||||
# common substring (e.g. "test", "data", "on").
|
||||
if database:
|
||||
sanitized = re.sub(r"\b" + re.escape(database) + r"\b", "<database>", sanitized)
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
|
||||
"""Extract keyword tokens from a parsed SQL statement.
|
||||
|
||||
Uses sqlparse token type classification to collect Keyword/DML/DDL/DCL
|
||||
tokens. String literals and identifiers have different token types, so
|
||||
they are naturally excluded from the result.
|
||||
"""
|
||||
return [
|
||||
token.normalized.upper()
|
||||
for token in parsed.flatten()
|
||||
if token.ttype
|
||||
in (
|
||||
sqlparse.tokens.Keyword,
|
||||
sqlparse.tokens.Keyword.DML,
|
||||
sqlparse.tokens.Keyword.DDL,
|
||||
sqlparse.tokens.Keyword.DCL,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _has_disallowed_into(stmt: sqlparse.sql.Statement) -> bool:
|
||||
"""Check if a statement contains a disallowed ``INTO`` clause.
|
||||
|
||||
``SELECT ... INTO @variable`` is a valid read-only MySQL syntax that stores
|
||||
a query result into a session-scoped user variable. All other forms of
|
||||
``INTO`` are data-modifying or file-writing and must be blocked:
|
||||
|
||||
* ``SELECT ... INTO new_table`` (PostgreSQL / MSSQL – creates a table)
|
||||
* ``SELECT ... INTO OUTFILE`` (MySQL – writes to the filesystem)
|
||||
* ``SELECT ... INTO DUMPFILE`` (MySQL – writes to the filesystem)
|
||||
* ``INSERT INTO ...`` (already blocked by INSERT being in the
|
||||
disallowed set, but we reject INTO as well for defense-in-depth)
|
||||
|
||||
Returns ``True`` if the statement contains a disallowed ``INTO``.
|
||||
"""
|
||||
flat = list(stmt.flatten())
|
||||
for i, token in enumerate(flat):
|
||||
if not (
|
||||
token.ttype in (sqlparse.tokens.Keyword,)
|
||||
and token.normalized.upper() == "INTO"
|
||||
):
|
||||
continue
|
||||
|
||||
# Look at the first non-whitespace token after INTO.
|
||||
j = i + 1
|
||||
while j < len(flat) and flat[j].ttype is sqlparse.tokens.Text.Whitespace:
|
||||
j += 1
|
||||
|
||||
if j >= len(flat):
|
||||
# INTO at the very end – malformed, block it.
|
||||
return True
|
||||
|
||||
next_token = flat[j]
|
||||
# MySQL user variable: either a single Name starting with "@"
|
||||
# (e.g. ``@total``) or a bare ``@`` Operator token followed by a Name.
|
||||
if next_token.ttype is sqlparse.tokens.Name and next_token.value.startswith(
|
||||
"@"
|
||||
):
|
||||
continue
|
||||
if next_token.ttype is sqlparse.tokens.Operator and next_token.value == "@":
|
||||
continue
|
||||
|
||||
# Everything else (table name, OUTFILE, DUMPFILE, etc.) is disallowed.
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
|
||||
"""Validate that a parsed SQL statement is read-only (SELECT/WITH only).
|
||||
|
||||
Accepts an already-parsed statement from ``_validate_single_statement``
|
||||
to avoid re-parsing. Checks:
|
||||
1. Statement type must be SELECT (sqlparse classifies WITH...SELECT as SELECT)
|
||||
2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, etc.)
|
||||
3. No disallowed INTO clauses (allows MySQL ``SELECT ... INTO @variable``)
|
||||
|
||||
Returns an error message if the query is not read-only, None otherwise.
|
||||
"""
|
||||
# sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
|
||||
if stmt.get_type() != "SELECT":
|
||||
return "Only SELECT queries are allowed."
|
||||
|
||||
# Defense-in-depth: check parsed keyword tokens for disallowed keywords
|
||||
for kw in _extract_keyword_tokens(stmt):
|
||||
# Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
|
||||
base_kw = kw.split()[0] if " " in kw else kw
|
||||
if base_kw in _DISALLOWED_KEYWORDS:
|
||||
return f"Disallowed SQL keyword: {kw}"
|
||||
|
||||
# Contextual check for INTO: allow MySQL @variable syntax, block everything else
|
||||
if _has_disallowed_into(stmt):
|
||||
return "Disallowed SQL keyword: INTO"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _validate_single_statement(
|
||||
query: str,
|
||||
) -> tuple[str | None, sqlparse.sql.Statement | None]:
|
||||
"""Validate that the query contains exactly one non-empty SQL statement.
|
||||
|
||||
Returns (error_message, parsed_statement). If error_message is not None,
|
||||
the query is invalid and parsed_statement will be None.
|
||||
"""
|
||||
stripped = query.strip().rstrip(";").strip()
|
||||
if not stripped:
|
||||
return "Query is empty.", None
|
||||
|
||||
# Parse the SQL using sqlparse for proper tokenization
|
||||
statements = sqlparse.parse(stripped)
|
||||
|
||||
# Filter out empty statements and comment-only statements
|
||||
statements = [
|
||||
s
|
||||
for s in statements
|
||||
if s.tokens
|
||||
and str(s).strip()
|
||||
and not all(
|
||||
t.is_whitespace or t.ttype in sqlparse.tokens.Comment for t in s.flatten()
|
||||
)
|
||||
]
|
||||
|
||||
if not statements:
|
||||
return "Query is empty.", None
|
||||
|
||||
# Reject multiple statements -- prevents injection via semicolons
|
||||
if len(statements) > 1:
|
||||
return "Only single statements are allowed.", None
|
||||
|
||||
return None, statements[0]
|
||||
|
||||
|
||||
def _serialize_value(value: Any) -> Any:
|
||||
"""Convert database-specific types to JSON-serializable Python types."""
|
||||
if isinstance(value, Decimal):
|
||||
# NaN / Infinity are not valid JSON numbers; serialize as strings.
|
||||
if value.is_nan() or value.is_infinite():
|
||||
return str(value)
|
||||
# Use int for whole numbers; use str for fractional to preserve exact
|
||||
# precision (float would silently round high-precision analytics values).
|
||||
if value == value.to_integral_value():
|
||||
return int(value)
|
||||
return str(value)
|
||||
if isinstance(value, (datetime, date, time)):
|
||||
return value.isoformat()
|
||||
if isinstance(value, memoryview):
|
||||
return bytes(value).hex()
|
||||
if isinstance(value, bytes):
|
||||
return value.hex()
|
||||
return value
|
||||
|
||||
|
||||
def _configure_session(
|
||||
conn: Any,
|
||||
dialect_name: str,
|
||||
timeout_ms: str,
|
||||
read_only: bool,
|
||||
) -> None:
|
||||
"""Set session-level timeout and read-only mode for the given dialect.
|
||||
|
||||
Timeout limitations by database:
|
||||
|
||||
* **PostgreSQL** – ``statement_timeout`` reliably cancels any running
|
||||
statement (SELECT or DML) after the configured duration.
|
||||
* **MySQL** – ``MAX_EXECUTION_TIME`` only applies to **read-only SELECT**
|
||||
statements. DML (INSERT/UPDATE/DELETE) and DDL are *not* bounded by
|
||||
this hint; they rely on the server's ``wait_timeout`` /
|
||||
``interactive_timeout`` instead. There is no session-level setting in
|
||||
MySQL that reliably cancels long-running writes.
|
||||
* **MSSQL** – ``SET LOCK_TIMEOUT`` only limits how long the server waits
|
||||
to acquire a **lock**. CPU-bound queries (e.g. large scans, hash
|
||||
joins) that do not block on locks will *not* be cancelled. MSSQL has
|
||||
no session-level ``statement_timeout`` equivalent; the closest
|
||||
mechanism is Resource Governor (requires sysadmin configuration) or
|
||||
``CONTEXT_INFO``-based external monitoring.
|
||||
|
||||
Note: SQLite is not supported by this block. The ``_configure_session``
|
||||
function is a no-op for unrecognised dialect names, so an SQLite engine
|
||||
would skip all SET commands silently. The block's ``DatabaseType`` enum
|
||||
intentionally excludes SQLite.
|
||||
"""
|
||||
if dialect_name == "postgresql":
|
||||
conn.execute(text("SET statement_timeout = " + timeout_ms))
|
||||
if read_only:
|
||||
conn.execute(text("SET default_transaction_read_only = ON"))
|
||||
elif dialect_name == "mysql":
|
||||
# NOTE: MAX_EXECUTION_TIME only applies to SELECT statements.
|
||||
# Write queries (INSERT/UPDATE/DELETE) are not bounded by this
|
||||
# setting; they rely on the database's wait_timeout instead.
|
||||
# See docstring above for full limitations.
|
||||
conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
|
||||
if read_only:
|
||||
conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
|
||||
elif dialect_name == "mssql":
|
||||
# MSSQL: SET LOCK_TIMEOUT limits lock-wait time (ms) only.
|
||||
# CPU-bound queries without lock contention are NOT cancelled.
|
||||
# See docstring above for full limitations.
|
||||
conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
|
||||
# MSSQL lacks a session-level read-only mode like
|
||||
# PostgreSQL/MySQL. Read-only enforcement is handled by
|
||||
# the SQL validation layer (_validate_query_is_read_only)
|
||||
# and the ROLLBACK in the finally block.
|
||||
|
||||
|
||||
def _run_in_transaction(
|
||||
conn: Any,
|
||||
dialect_name: str,
|
||||
query: str,
|
||||
max_rows: int,
|
||||
read_only: bool,
|
||||
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
|
||||
"""Execute a query inside an explicit transaction, returning results.
|
||||
|
||||
Returns ``(rows, columns, affected_rows, truncated)`` where *truncated*
|
||||
is ``True`` when ``fetchmany`` returned exactly ``max_rows`` rows,
|
||||
indicating that additional rows may exist in the result set.
|
||||
"""
|
||||
# MSSQL uses T-SQL "BEGIN TRANSACTION"; others use "BEGIN".
|
||||
begin_stmt = "BEGIN TRANSACTION" if dialect_name == "mssql" else "BEGIN"
|
||||
conn.execute(text(begin_stmt))
|
||||
try:
|
||||
result = conn.execute(text(query))
|
||||
affected = result.rowcount if not result.returns_rows else -1
|
||||
columns = list(result.keys()) if result.returns_rows else []
|
||||
rows = result.fetchmany(max_rows) if result.returns_rows else []
|
||||
truncated = len(rows) == max_rows
|
||||
results = [
|
||||
{col: _serialize_value(val) for col, val in zip(columns, row)}
|
||||
for row in rows
|
||||
]
|
||||
except Exception:
|
||||
try:
|
||||
conn.execute(text("ROLLBACK"))
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
else:
|
||||
conn.execute(text("ROLLBACK" if read_only else "COMMIT"))
|
||||
return results, columns, affected, truncated
|
||||
|
||||
|
||||
def _execute_query(
|
||||
connection_url: URL | str,
|
||||
query: str,
|
||||
timeout: int,
|
||||
max_rows: int,
|
||||
read_only: bool = True,
|
||||
database_type: DatabaseType = DatabaseType.POSTGRES,
|
||||
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
|
||||
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
|
||||
|
||||
Uses SQLAlchemy to connect to any supported database.
|
||||
For SELECT queries, rows are limited to ``max_rows`` via DBAPI fetchmany.
|
||||
``truncated`` is ``True`` when the result set was capped by ``max_rows``.
|
||||
For write queries, affected_rows contains the rowcount from the driver.
|
||||
When ``read_only`` is True, the database session is set to read-only
|
||||
mode and the transaction is always rolled back.
|
||||
"""
|
||||
# Determine driver-specific connection timeout argument.
|
||||
# pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
|
||||
timeout_key = (
|
||||
"login_timeout" if database_type == DatabaseType.MSSQL else "connect_timeout"
|
||||
)
|
||||
engine = create_engine(
|
||||
connection_url, connect_args={timeout_key: _CONNECT_TIMEOUT_SECONDS}
|
||||
)
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
# Use AUTOCOMMIT so SET commands take effect immediately.
|
||||
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
|
||||
|
||||
# Compute timeout in milliseconds. The value is Pydantic-validated
|
||||
# (ge=1, le=120), but we use int() as defense-in-depth.
|
||||
# NOTE: SET commands do not support bind parameters in most
|
||||
# databases, so we use str(int(...)) for safe interpolation.
|
||||
timeout_ms = str(int(timeout * 1000))
|
||||
|
||||
_configure_session(conn, engine.dialect.name, timeout_ms, read_only)
|
||||
return _run_in_transaction(
|
||||
conn, engine.dialect.name, query, max_rows, read_only
|
||||
)
|
||||
finally:
|
||||
engine.dispose()
|
||||
@@ -1,13 +1,14 @@
|
||||
"""Tests for AutoPilotBlock: recursion guard, streaming, validation, and error paths."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AUTOPILOT_BLOCK_ID,
|
||||
AutoPilotBlock,
|
||||
SubAgentRecursionError,
|
||||
_autopilot_recursion_depth,
|
||||
_autopilot_recursion_limit,
|
||||
_check_recursion,
|
||||
@@ -57,7 +58,7 @@ class TestCheckRecursion:
|
||||
try:
|
||||
t2 = _check_recursion(2)
|
||||
try:
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
with pytest.raises(SubAgentRecursionError):
|
||||
_check_recursion(2)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
@@ -71,7 +72,7 @@ class TestCheckRecursion:
|
||||
t2 = _check_recursion(10) # inner wants 10, but inherited is 2
|
||||
try:
|
||||
# depth is now 2, limit is min(10, 2) = 2 → should raise
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
with pytest.raises(SubAgentRecursionError):
|
||||
_check_recursion(10)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
@@ -81,7 +82,7 @@ class TestCheckRecursion:
|
||||
def test_limit_of_one_blocks_immediately_on_second_call(self):
|
||||
t1 = _check_recursion(1)
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
with pytest.raises(SubAgentRecursionError):
|
||||
_check_recursion(1)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
@@ -175,6 +176,29 @@ class TestRunValidation:
|
||||
assert outputs["session_id"] == "sess-cancel"
|
||||
assert "cancelled" in outputs.get("error", "").lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dry_run_inherited_from_execution_context(self, block):
|
||||
"""execution_context.dry_run=True must be OR-ed into create_session dry_run
|
||||
so that nested AutoPilot sessions simulate even when input_data.dry_run=False.
|
||||
"""
|
||||
mock_result = (
|
||||
"ok",
|
||||
[],
|
||||
"[]",
|
||||
"sess-dry",
|
||||
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
block.execute_copilot = AsyncMock(return_value=mock_result)
|
||||
block.create_session = AsyncMock(return_value="sess-dry")
|
||||
|
||||
input_data = block.Input(prompt="test", max_recursion_depth=3, dry_run=False)
|
||||
ctx = _make_context()
|
||||
ctx.dry_run = True # outer execution is dry_run
|
||||
async for _ in block.run(input_data, execution_context=ctx):
|
||||
pass
|
||||
|
||||
block.create_session.assert_called_once_with(ctx.user_id, dry_run=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_existing_session_id_skips_create(self, block):
|
||||
"""When session_id is provided, create_session should not be called."""
|
||||
@@ -221,3 +245,171 @@ class TestBlockRegistration:
|
||||
# The field should exist (inherited) but there should be no explicit
|
||||
# redefinition. We verify by checking the class __annotations__ directly.
|
||||
assert "error" not in AutoPilotBlock.Output.__annotations__
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recovery enqueue integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecoveryEnqueue:
|
||||
"""Tests that run() enqueues orphaned sessions for recovery on failure."""
|
||||
|
||||
@pytest.fixture
|
||||
def block(self):
|
||||
return AutoPilotBlock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_enqueued_on_transient_exception(self, block):
|
||||
"""A generic exception should trigger _enqueue_for_recovery."""
|
||||
block.execute_copilot = AsyncMock(side_effect=RuntimeError("network error"))
|
||||
block.create_session = AsyncMock(return_value="sess-recover")
|
||||
|
||||
input_data = block.Input(prompt="do work", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
|
||||
with patch("backend.blocks.autopilot._enqueue_for_recovery") as mock_enqueue:
|
||||
mock_enqueue.return_value = None
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert "network error" in outputs.get("error", "")
|
||||
mock_enqueue.assert_awaited_once_with(
|
||||
"sess-recover",
|
||||
ctx.user_id,
|
||||
"do work",
|
||||
False,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_not_enqueued_for_recursion_limit(self, block):
|
||||
"""Recursion limit errors are deliberate — no recovery enqueue."""
|
||||
block.execute_copilot = AsyncMock(
|
||||
side_effect=SubAgentRecursionError(
|
||||
"AutoPilot recursion depth limit reached (3). "
|
||||
"The autopilot has called itself too many times."
|
||||
)
|
||||
)
|
||||
block.create_session = AsyncMock(return_value="sess-rec-limit")
|
||||
|
||||
input_data = block.Input(prompt="recurse", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
|
||||
with patch("backend.blocks.autopilot._enqueue_for_recovery") as mock_enqueue:
|
||||
async for _ in block.run(input_data, execution_context=ctx):
|
||||
pass
|
||||
|
||||
mock_enqueue.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_not_enqueued_for_dry_run(self, block):
|
||||
"""dry_run=True sessions must not be enqueued (no real consumers)."""
|
||||
block.execute_copilot = AsyncMock(side_effect=RuntimeError("transient"))
|
||||
block.create_session = AsyncMock(return_value="sess-dry-fail")
|
||||
|
||||
input_data = block.Input(prompt="test", max_recursion_depth=3, dry_run=True)
|
||||
ctx = _make_context()
|
||||
|
||||
with patch("backend.blocks.autopilot._enqueue_for_recovery") as mock_enqueue:
|
||||
mock_enqueue.return_value = None
|
||||
async for _ in block.run(input_data, execution_context=ctx):
|
||||
pass
|
||||
|
||||
# _enqueue_for_recovery is called with dry_run=True,
|
||||
# so the inner guard returns early without publishing to the queue.
|
||||
mock_enqueue.assert_awaited_once()
|
||||
positional = mock_enqueue.call_args_list[0][0]
|
||||
assert positional[3] is True # dry_run=True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_enqueue_failure_does_not_mask_original_error(self, block):
|
||||
"""If _enqueue_for_recovery itself raises, the original error is still yielded."""
|
||||
block.execute_copilot = AsyncMock(side_effect=ValueError("original"))
|
||||
block.create_session = AsyncMock(return_value="sess-enq-fail")
|
||||
|
||||
input_data = block.Input(prompt="hello", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
|
||||
async def _failing_enqueue(*args, **kwargs):
|
||||
raise OSError("rabbitmq down")
|
||||
|
||||
with patch(
|
||||
"backend.blocks.autopilot._enqueue_for_recovery",
|
||||
side_effect=_failing_enqueue,
|
||||
):
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
# Original error must still be surfaced despite the enqueue failure
|
||||
assert outputs.get("error") == "original"
|
||||
assert outputs.get("session_id") == "sess-enq-fail"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_uses_dry_run_from_context(self, block):
|
||||
"""execution_context.dry_run=True is OR-ed into the dry_run arg."""
|
||||
block.execute_copilot = AsyncMock(side_effect=RuntimeError("fail"))
|
||||
block.create_session = AsyncMock(return_value="sess-ctx-dry")
|
||||
|
||||
input_data = block.Input(prompt="test", max_recursion_depth=3, dry_run=False)
|
||||
ctx = _make_context()
|
||||
ctx.dry_run = True # outer execution is dry_run
|
||||
|
||||
with patch("backend.blocks.autopilot._enqueue_for_recovery") as mock_enqueue:
|
||||
mock_enqueue.return_value = None
|
||||
async for _ in block.run(input_data, execution_context=ctx):
|
||||
pass
|
||||
|
||||
mock_enqueue.assert_awaited_once()
|
||||
positional = mock_enqueue.call_args_list[0][0]
|
||||
assert positional[3] is True # dry_run=True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_uses_effective_prompt_with_system_context(self, block):
|
||||
"""When system_context is set, _enqueue_for_recovery receives the
|
||||
effective_prompt (system_context prepended) so the dedup check in
|
||||
maybe_append_user_message passes on replay."""
|
||||
block.execute_copilot = AsyncMock(side_effect=RuntimeError("e2b timeout"))
|
||||
block.create_session = AsyncMock(return_value="sess-sys-ctx")
|
||||
|
||||
input_data = block.Input(
|
||||
prompt="do work",
|
||||
system_context="Be concise.",
|
||||
max_recursion_depth=3,
|
||||
)
|
||||
ctx = _make_context()
|
||||
|
||||
with patch("backend.blocks.autopilot._enqueue_for_recovery") as mock_enqueue:
|
||||
mock_enqueue.return_value = None
|
||||
async for _ in block.run(input_data, execution_context=ctx):
|
||||
pass
|
||||
|
||||
mock_enqueue.assert_awaited_once()
|
||||
positional = mock_enqueue.call_args_list[0][0]
|
||||
assert positional[2] == "[System Context: Be concise.]\n\ndo work"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_cancelled_error_still_yields_error(self, block):
|
||||
"""CancelledError during _enqueue_for_recovery still yields the error output."""
|
||||
block.execute_copilot = AsyncMock(side_effect=RuntimeError("e2b stall"))
|
||||
block.create_session = AsyncMock(return_value="sess-cancel")
|
||||
|
||||
async def _cancelled_enqueue(*args, **kwargs):
|
||||
raise asyncio.CancelledError
|
||||
|
||||
outputs = {}
|
||||
with patch(
|
||||
"backend.blocks.autopilot._enqueue_for_recovery",
|
||||
side_effect=_cancelled_enqueue,
|
||||
):
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
async for name, value in block.run(
|
||||
block.Input(prompt="do work", max_recursion_depth=3),
|
||||
execution_context=_make_context(),
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# error must be yielded even when recovery raises CancelledError
|
||||
assert outputs.get("error") == "e2b stall"
|
||||
assert outputs.get("session_id") == "sess-cancel"
|
||||
|
||||
@@ -300,13 +300,27 @@ def test_agent_input_block_ignores_legacy_placeholder_values():
|
||||
|
||||
|
||||
def test_dropdown_input_block_produces_enum():
|
||||
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
|
||||
options = ["Option A", "Option B"]
|
||||
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum
|
||||
using the canonical 'options' field name."""
|
||||
opts = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_construct(
|
||||
name="choice", value=None, placeholder_values=options
|
||||
name="choice", value=None, options=opts
|
||||
)
|
||||
schema = instance.generate_schema()
|
||||
assert schema.get("enum") == options
|
||||
assert schema.get("enum") == opts
|
||||
|
||||
|
||||
def test_dropdown_input_block_legacy_placeholder_values_produces_enum():
|
||||
"""Verify backward compat: passing legacy 'placeholder_values' to
|
||||
AgentDropdownInputBlock still produces enum via model_construct remap."""
|
||||
opts = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_construct(
|
||||
name="choice", value=None, placeholder_values=opts
|
||||
)
|
||||
schema = instance.generate_schema()
|
||||
assert (
|
||||
schema.get("enum") == opts
|
||||
), "Legacy placeholder_values should be remapped to options"
|
||||
|
||||
|
||||
def test_generate_schema_integration_legacy_placeholder_values():
|
||||
@@ -329,11 +343,11 @@ def test_generate_schema_integration_legacy_placeholder_values():
|
||||
|
||||
def test_generate_schema_integration_dropdown_produces_enum():
|
||||
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
|
||||
— verifies enum IS produced for dropdown blocks."""
|
||||
— verifies enum IS produced for dropdown blocks using canonical field name."""
|
||||
dropdown_input_default = {
|
||||
"name": "color",
|
||||
"value": None,
|
||||
"placeholder_values": ["Red", "Green", "Blue"],
|
||||
"options": ["Red", "Green", "Blue"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentDropdownInputBlock.Input, dropdown_input_default),
|
||||
@@ -344,3 +358,36 @@ def test_generate_schema_integration_dropdown_produces_enum():
|
||||
"Green",
|
||||
"Blue",
|
||||
], "Graph schema should contain enum from AgentDropdownInputBlock"
|
||||
|
||||
|
||||
def test_generate_schema_integration_dropdown_legacy_placeholder_values():
|
||||
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
|
||||
using legacy 'placeholder_values' — verifies backward compat produces enum."""
|
||||
legacy_dropdown_input_default = {
|
||||
"name": "color",
|
||||
"value": None,
|
||||
"placeholder_values": ["Red", "Green", "Blue"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentDropdownInputBlock.Input, legacy_dropdown_input_default),
|
||||
)
|
||||
color_props = result["properties"]["color"]
|
||||
assert color_props.get("enum") == [
|
||||
"Red",
|
||||
"Green",
|
||||
"Blue",
|
||||
], "Legacy placeholder_values should still produce enum via model_construct remap"
|
||||
|
||||
|
||||
def test_dropdown_input_block_init_legacy_placeholder_values():
|
||||
"""Verify backward compat: constructing AgentDropdownInputBlock.Input via
|
||||
model_validate with legacy 'placeholder_values' correctly maps to 'options'."""
|
||||
opts = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_validate(
|
||||
{"name": "choice", "value": None, "placeholder_values": opts}
|
||||
)
|
||||
assert (
|
||||
instance.options == opts
|
||||
), "Legacy placeholder_values should be remapped to options via model_validate"
|
||||
schema = instance.generate_schema()
|
||||
assert schema.get("enum") == opts
|
||||
|
||||
@@ -46,6 +46,110 @@ class TestLLMStatsTracking:
|
||||
assert response.completion_tokens == 20
|
||||
assert response.response == "Test response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_anthropic_returns_cache_tokens(self):
|
||||
"""Test that llm_call returns cache read/creation tokens from Anthropic."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.data.model import APIKeyCredentials
|
||||
|
||||
anthropic_creds = APIKeyCredentials(
|
||||
id="test-anthropic-id",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr("mock-anthropic-key"),
|
||||
title="Mock Anthropic key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
mock_content_block = MagicMock()
|
||||
mock_content_block.type = "text"
|
||||
mock_content_block.text = "Test anthropic response"
|
||||
|
||||
mock_usage = MagicMock()
|
||||
mock_usage.input_tokens = 15
|
||||
mock_usage.output_tokens = 25
|
||||
mock_usage.cache_read_input_tokens = 100
|
||||
mock_usage.cache_creation_input_tokens = 50
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [mock_content_block]
|
||||
mock_response.usage = mock_usage
|
||||
mock_response.stop_reason = "end_turn"
|
||||
|
||||
with (
|
||||
patch("anthropic.AsyncAnthropic") as mock_anthropic,
|
||||
patch("backend.blocks.llm.settings") as mock_settings,
|
||||
):
|
||||
mock_settings.secrets.open_router_api_key = ""
|
||||
mock_client = AsyncMock()
|
||||
mock_anthropic.return_value = mock_client
|
||||
mock_client.messages.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
response = await llm.llm_call(
|
||||
credentials=anthropic_creds,
|
||||
llm_model=llm.LlmModel.CLAUDE_3_HAIKU,
|
||||
prompt=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
assert isinstance(response, llm.LLMResponse)
|
||||
assert response.prompt_tokens == 15
|
||||
assert response.completion_tokens == 25
|
||||
assert response.cache_read_tokens == 100
|
||||
assert response.cache_creation_tokens == 50
|
||||
assert response.response == "Test anthropic response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_routes_through_openrouter_when_key_present(self):
|
||||
"""When open_router_api_key is set, Anthropic models route via OpenRouter."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.data.model import APIKeyCredentials
|
||||
|
||||
anthropic_creds = APIKeyCredentials(
|
||||
id="test-anthropic-id",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr("mock-anthropic-key"),
|
||||
title="Mock Anthropic key",
|
||||
)
|
||||
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = "routed response"
|
||||
mock_choice.message.tool_calls = None
|
||||
|
||||
mock_usage = MagicMock()
|
||||
mock_usage.prompt_tokens = 10
|
||||
mock_usage.completion_tokens = 5
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
mock_response.usage = mock_usage
|
||||
|
||||
mock_create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with (
|
||||
patch("openai.AsyncOpenAI") as mock_openai,
|
||||
patch("backend.blocks.llm.settings") as mock_settings,
|
||||
):
|
||||
mock_settings.secrets.open_router_api_key = "sk-or-test-key"
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_client.chat.completions.create = mock_create
|
||||
|
||||
await llm.llm_call(
|
||||
credentials=anthropic_creds,
|
||||
llm_model=llm.LlmModel.CLAUDE_3_HAIKU,
|
||||
prompt=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
# Verify OpenAI client was used (not Anthropic SDK) and model was prefixed
|
||||
mock_openai.assert_called_once()
|
||||
call_kwargs = mock_create.call_args.kwargs
|
||||
assert call_kwargs["model"] == "anthropic/claude-3-haiku-20240307"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_structured_response_block_tracks_stats(self):
|
||||
"""Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""
|
||||
@@ -199,6 +303,139 @@ class TestLLMStatsTracking:
|
||||
assert block.execution_stats.llm_call_count == 2 # retry_count + 1 = 1 + 1 = 2
|
||||
assert block.execution_stats.llm_retry_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_cost_accumulates_across_attempts(self):
|
||||
"""provider_cost accumulates across all retry attempts.
|
||||
|
||||
Each LLM call incurs a real cost, including failed validation attempts.
|
||||
The total cost is the sum of all attempts so no billed USD is lost.
|
||||
"""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# First attempt: fails validation, returns cost $0.01
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='<json_output id="test123456">{"wrong": "key"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
reasoning=None,
|
||||
provider_cost=0.01,
|
||||
)
|
||||
# Second attempt: succeeds, returns cost $0.02
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=20,
|
||||
completion_tokens=10,
|
||||
reasoning=None,
|
||||
provider_cost=0.02,
|
||||
)
|
||||
|
||||
block.llm_call = mock_llm_call # type: ignore
|
||||
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
)
|
||||
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
# provider_cost accumulates across all attempts: $0.01 + $0.02 = $0.03
|
||||
assert block.execution_stats.provider_cost == pytest.approx(0.03)
|
||||
# Tokens from both attempts accumulate
|
||||
assert block.execution_stats.input_token_count == 30
|
||||
assert block.execution_stats.output_token_count == 15
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_tokens_accumulated_in_stats(self):
|
||||
"""Cache read/creation tokens are tracked per-attempt and accumulated."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='<json_output id="tok123456">{"key1": "v1", "key2": "v2"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
cache_read_tokens=20,
|
||||
cache_creation_tokens=8,
|
||||
reasoning=None,
|
||||
provider_cost=0.005,
|
||||
)
|
||||
|
||||
block.llm_call = mock_llm_call # type: ignore
|
||||
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=1,
|
||||
)
|
||||
|
||||
with patch("secrets.token_hex", return_value="tok123456"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert block.execution_stats.cache_read_token_count == 20
|
||||
assert block.execution_stats.cache_creation_token_count == 8
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_path_persists_accumulated_cost(self):
|
||||
"""When all retries are exhausted, accumulated provider_cost is preserved."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response="not valid json at all",
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
reasoning=None,
|
||||
provider_cost=0.01,
|
||||
)
|
||||
|
||||
block.llm_call = mock_llm_call # type: ignore
|
||||
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
# Both retry attempts each cost $0.01, total $0.02
|
||||
assert block.execution_stats.provider_cost == pytest.approx(0.02)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_text_summarizer_multiple_chunks(self):
|
||||
"""Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""
|
||||
@@ -488,6 +725,154 @@ class TestLLMStatsTracking:
|
||||
assert outputs["response"] == {"result": "test"}
|
||||
|
||||
|
||||
class TestAIConversationBlockValidation:
|
||||
"""Test that AIConversationBlock validates inputs before calling the LLM."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_messages_and_empty_prompt_raises_error(self):
|
||||
"""Empty messages with no prompt should raise ValueError, not a cryptic API error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_messages_with_prompt_succeeds(self):
|
||||
"""Empty messages but a non-empty prompt should proceed without error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
async def mock_llm_call(input_data, credentials):
|
||||
return {"response": "OK"}
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[],
|
||||
prompt="Hello, how are you?",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[name] = data
|
||||
|
||||
assert outputs["response"] == "OK"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonempty_messages_with_empty_prompt_succeeds(self):
|
||||
"""Non-empty messages with no prompt should proceed without error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
async def mock_llm_call(input_data, credentials):
|
||||
return {"response": "response from conversation"}
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[name] = data
|
||||
|
||||
assert outputs["response"] == "response from conversation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_empty_content_raises_error(self):
|
||||
"""Messages with empty content strings should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": ""}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_whitespace_content_raises_error(self):
|
||||
"""Messages with whitespace-only content should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": " "}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_none_entry_raises_error(self):
|
||||
"""Messages list containing None should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[None],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_empty_dict_raises_error(self):
|
||||
"""Messages list containing empty dict should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_none_content_raises_error(self):
|
||||
"""Messages with content=None should not crash with AttributeError."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": None}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
class TestAITextSummarizerValidation:
|
||||
"""Test that AITextSummarizerBlock validates LLM responses are strings."""
|
||||
|
||||
@@ -839,3 +1224,295 @@ class TestLlmModelMissing:
|
||||
assert (
|
||||
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
|
||||
)
|
||||
|
||||
|
||||
class TestExtractOpenRouterCost:
|
||||
"""Tests for extract_openrouter_cost — the x-total-cost header parser."""
|
||||
|
||||
def _mk_response(self, headers: dict | None):
|
||||
response = MagicMock()
|
||||
if headers is None:
|
||||
response._response = None
|
||||
else:
|
||||
raw = MagicMock()
|
||||
raw.headers = headers
|
||||
response._response = raw
|
||||
return response
|
||||
|
||||
def test_extracts_numeric_cost(self):
|
||||
response = self._mk_response({"x-total-cost": "0.0042"})
|
||||
assert llm.extract_openrouter_cost(response) == 0.0042
|
||||
|
||||
def test_returns_none_when_header_missing(self):
|
||||
response = self._mk_response({})
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_when_header_empty_string(self):
|
||||
response = self._mk_response({"x-total-cost": ""})
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_when_header_non_numeric(self):
|
||||
response = self._mk_response({"x-total-cost": "not-a-number"})
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_when_no_response_attr(self):
|
||||
response = MagicMock(spec=[]) # no _response attr
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_when_raw_is_none(self):
|
||||
response = self._mk_response(None)
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_when_raw_has_no_headers(self):
|
||||
response = MagicMock()
|
||||
response._response = MagicMock(spec=[]) # no headers attr
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_zero_for_zero_cost(self):
|
||||
"""Zero-cost is a valid value (free tier) and must not become None."""
|
||||
response = self._mk_response({"x-total-cost": "0"})
|
||||
assert llm.extract_openrouter_cost(response) == 0.0
|
||||
|
||||
def test_returns_none_for_inf(self):
|
||||
response = self._mk_response({"x-total-cost": "inf"})
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_for_negative_inf(self):
|
||||
response = self._mk_response({"x-total-cost": "-inf"})
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_for_nan(self):
|
||||
response = self._mk_response({"x-total-cost": "nan"})
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_for_negative_cost(self):
|
||||
response = self._mk_response({"x-total-cost": "-0.005"})
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
|
||||
class TestAnthropicCacheControl:
|
||||
"""Verify that llm_call attaches cache_control to the system prompt block
|
||||
and to the last tool definition when calling the Anthropic API."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_openrouter_routing(self):
|
||||
"""Ensure tests exercise the direct-Anthropic path by suppressing the
|
||||
OpenRouter API key. Without this, a local .env with OPEN_ROUTER_API_KEY
|
||||
set would silently reroute all Anthropic calls through OpenRouter,
|
||||
bypassing the cache_control code under test."""
|
||||
with patch("backend.blocks.llm.settings") as mock_settings:
|
||||
mock_settings.secrets.open_router_api_key = ""
|
||||
yield mock_settings
|
||||
|
||||
def _make_anthropic_credentials(self) -> llm.APIKeyCredentials:
|
||||
from pydantic import SecretStr
|
||||
|
||||
return llm.APIKeyCredentials(
|
||||
id="test-anthropic-id",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr("mock-anthropic-key"),
|
||||
title="Mock Anthropic key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_prompt_sent_as_block_with_cache_control(self):
|
||||
"""The system prompt is wrapped in a structured block with cache_control ephemeral."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.content = [MagicMock(type="text", text="hello")]
|
||||
mock_resp.usage = MagicMock(input_tokens=5, output_tokens=3)
|
||||
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return mock_resp
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.messages.create = fake_create
|
||||
|
||||
credentials = self._make_anthropic_credentials()
|
||||
|
||||
with patch("anthropic.AsyncAnthropic", return_value=mock_client):
|
||||
await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm.LlmModel.CLAUDE_4_6_SONNET,
|
||||
prompt=[
|
||||
{"role": "system", "content": "You are an assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
],
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
system_arg = captured_kwargs.get("system")
|
||||
assert isinstance(system_arg, list), "system should be a list of blocks"
|
||||
assert len(system_arg) == 1
|
||||
block = system_arg[0]
|
||||
assert block["type"] == "text"
|
||||
assert block["text"] == "You are an assistant."
|
||||
assert block.get("cache_control") == {"type": "ephemeral"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_last_tool_gets_cache_control(self):
|
||||
"""cache_control is placed on the last tool in the Anthropic tools list."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.content = [MagicMock(type="text", text="ok")]
|
||||
mock_resp.usage = MagicMock(input_tokens=10, output_tokens=5)
|
||||
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return mock_resp
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.messages.create = fake_create
|
||||
|
||||
credentials = self._make_anthropic_credentials()
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_a",
|
||||
"description": "First tool",
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_b",
|
||||
"description": "Second tool",
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with patch("anthropic.AsyncAnthropic", return_value=mock_client):
|
||||
await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm.LlmModel.CLAUDE_4_6_SONNET,
|
||||
prompt=[
|
||||
{"role": "system", "content": "System."},
|
||||
{"role": "user", "content": "Do something"},
|
||||
],
|
||||
max_tokens=100,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
an_tools = captured_kwargs.get("tools")
|
||||
assert isinstance(an_tools, list)
|
||||
assert len(an_tools) == 2
|
||||
assert (
|
||||
an_tools[0].get("cache_control") is None
|
||||
), "Only last tool gets cache_control"
|
||||
assert an_tools[-1].get("cache_control") == {"type": "ephemeral"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tools_no_cache_control_on_tools(self):
|
||||
"""When there are no tools, the Anthropic call receives anthropic.NOT_GIVEN for tools."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.content = [MagicMock(type="text", text="ok")]
|
||||
mock_resp.usage = MagicMock(input_tokens=5, output_tokens=2)
|
||||
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return mock_resp
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.messages.create = fake_create
|
||||
|
||||
credentials = self._make_anthropic_credentials()
|
||||
|
||||
with patch("anthropic.AsyncAnthropic", return_value=mock_client):
|
||||
await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm.LlmModel.CLAUDE_4_6_SONNET,
|
||||
prompt=[
|
||||
{"role": "system", "content": "System."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
],
|
||||
max_tokens=100,
|
||||
tools=None,
|
||||
)
|
||||
|
||||
import anthropic
|
||||
|
||||
tools_arg = captured_kwargs.get("tools")
|
||||
assert (
|
||||
tools_arg is anthropic.NOT_GIVEN
|
||||
), "Empty tools should pass anthropic.NOT_GIVEN sentinel"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_system_prompt_omits_system_key(self):
|
||||
"""When sysprompt is empty, the 'system' key must not be sent to Anthropic.
|
||||
|
||||
Anthropic rejects empty text blocks; the guard in llm_call must ensure
|
||||
the system argument is omitted entirely when no system messages are present.
|
||||
"""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.content = [MagicMock(type="text", text="ok")]
|
||||
mock_resp.usage = MagicMock(input_tokens=3, output_tokens=2)
|
||||
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return mock_resp
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.messages.create = fake_create
|
||||
|
||||
credentials = self._make_anthropic_credentials()
|
||||
|
||||
with patch("anthropic.AsyncAnthropic", return_value=mock_client):
|
||||
await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm.LlmModel.CLAUDE_4_6_SONNET,
|
||||
prompt=[{"role": "user", "content": "Hi"}],
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
assert (
|
||||
"system" not in captured_kwargs
|
||||
), "system must be omitted when sysprompt is empty to avoid Anthropic 400"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_only_system_prompt_omits_system_key(self):
|
||||
"""Whitespace-only system content is treated as empty and omitted.
|
||||
|
||||
The guard in llm_call uses sysprompt.strip() so a prompt consisting of
|
||||
only whitespace should NOT reach the Anthropic API (it would be rejected
|
||||
as an empty text block).
|
||||
"""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.content = [MagicMock(type="text", text="ok")]
|
||||
mock_resp.usage = MagicMock(input_tokens=3, output_tokens=2)
|
||||
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return mock_resp
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.messages.create = fake_create
|
||||
|
||||
credentials = self._make_anthropic_credentials()
|
||||
|
||||
with patch("anthropic.AsyncAnthropic", return_value=mock_client):
|
||||
await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm.LlmModel.CLAUDE_4_6_SONNET,
|
||||
prompt=[
|
||||
{"role": "system", "content": " \n\t "},
|
||||
{"role": "user", "content": "Hi"},
|
||||
],
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
assert (
|
||||
"system" not in captured_kwargs
|
||||
), "whitespace-only sysprompt must be omitted to avoid Anthropic 400"
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
|
||||
|
||||
|
||||
class TestExtractOpenaiToolCallsEmptyChoices:
|
||||
"""extract_openai_tool_calls() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_tool_calls_when_choices_present(self):
|
||||
tool = MagicMock()
|
||||
tool.id = "call_1"
|
||||
tool.type = "function"
|
||||
tool.function.name = "my_func"
|
||||
tool.function.arguments = '{"a": 1}'
|
||||
|
||||
message = MagicMock()
|
||||
message.tool_calls = [tool]
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_tool_calls(response)
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].function.name == "my_func"
|
||||
|
||||
def test_returns_none_when_no_tool_calls(self):
|
||||
message = MagicMock()
|
||||
message.tool_calls = None
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
|
||||
class TestExtractOpenaiReasoningEmptyChoices:
|
||||
"""extract_openai_reasoning() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_reasoning_from_choice(self):
|
||||
choice = MagicMock()
|
||||
choice.reasoning = "Step-by-step reasoning"
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr on response
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result == "Step-by-step reasoning"
|
||||
|
||||
def test_returns_none_when_no_reasoning(self):
|
||||
choice = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result is None
|
||||
@@ -922,6 +922,11 @@ async def test_orchestrator_agent_mode():
|
||||
mock_execution_processor.on_node_execution = AsyncMock(
|
||||
return_value=mock_node_stats
|
||||
)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
# Returns (cost, remaining_balance). Must be AsyncMock because it is
|
||||
# an async method and is directly awaited in _execute_single_tool_with_manager.
|
||||
# Use a non-zero cost so the merge_stats branch is exercised.
|
||||
mock_execution_processor.charge_node_usage = AsyncMock(return_value=(10, 990))
|
||||
|
||||
# Mock the get_execution_outputs_by_node_exec_id method
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
|
||||
@@ -967,6 +972,11 @@ async def test_orchestrator_agent_mode():
|
||||
# Verify tool was executed via execution processor
|
||||
assert mock_execution_processor.on_node_execution.call_count == 1
|
||||
|
||||
# Verify charge_node_usage was actually called for the successful
|
||||
# tool execution — this guards against regressions where the
|
||||
# post-execution tool charging is accidentally removed.
|
||||
assert mock_execution_processor.charge_node_usage.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_traditional_mode_default():
|
||||
@@ -1074,6 +1084,7 @@ async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
@@ -1105,6 +1116,7 @@ async def test_orchestrator_falls_back_to_block_name():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
|
||||
@@ -306,6 +306,9 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
mock_response.raw_response = {"role": "assistant", "content": "test"}
|
||||
mock_response.prompt_tokens = 100
|
||||
mock_response.completion_tokens = 50
|
||||
mock_response.cache_read_tokens = 0
|
||||
mock_response.cache_creation_tokens = 0
|
||||
mock_response.provider_cost = None
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
@@ -638,6 +641,14 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
mock_execution_processor.on_node_execution.return_value = (
|
||||
mock_node_stats
|
||||
)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
# Must be AsyncMock because it is async and is awaited in
|
||||
# _execute_single_tool_with_manager — a plain MagicMock would
|
||||
# return a non-awaitable tuple and TypeError out, then be
|
||||
# silently swallowed by the orchestrator's catch-all.
|
||||
mock_execution_processor.charge_node_usage = AsyncMock(
|
||||
return_value=(0, 0)
|
||||
)
|
||||
|
||||
async for output_name, output_value in block.run(
|
||||
input_data,
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
|
||||
|
||||
Covers:
|
||||
- ExecutionMode enum members exist and have stable values
|
||||
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
|
||||
- EXTENDED_THINKING model-name validation (must start with "claude")
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExecutionMode enum integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecutionModeEnum:
|
||||
"""Guard against accidental renames or removals of enum members."""
|
||||
|
||||
def test_built_in_exists(self):
|
||||
assert hasattr(ExecutionMode, "BUILT_IN")
|
||||
assert ExecutionMode.BUILT_IN.value == "built_in"
|
||||
|
||||
def test_extended_thinking_exists(self):
|
||||
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
|
||||
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
|
||||
|
||||
def test_exactly_two_members(self):
|
||||
"""If a new mode is added, this test should be updated intentionally."""
|
||||
assert set(ExecutionMode.__members__.keys()) == {
|
||||
"BUILT_IN",
|
||||
"EXTENDED_THINKING",
|
||||
}
|
||||
|
||||
def test_string_enum(self):
|
||||
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
|
||||
assert isinstance(ExecutionMode.BUILT_IN, str)
|
||||
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
|
||||
|
||||
def test_round_trip_from_value(self):
|
||||
"""Constructing from the string value should return the same member."""
|
||||
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
|
||||
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider validation (inline in OrchestratorBlock.run)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_model_stub(provider: str, value: str):
|
||||
"""Create a lightweight stub that behaves like LlmModel for validation."""
|
||||
metadata = MagicMock()
|
||||
metadata.provider = provider
|
||||
stub = MagicMock()
|
||||
stub.metadata = metadata
|
||||
stub.value = value
|
||||
return stub
|
||||
|
||||
|
||||
class TestExtendedThinkingProviderValidation:
|
||||
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
|
||||
|
||||
def test_anthropic_provider_accepted(self):
|
||||
"""provider='anthropic' + claude model should not raise."""
|
||||
model = _make_model_stub("anthropic", "claude-opus-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_open_router_provider_accepted(self):
|
||||
"""provider='open_router' + claude model should not raise."""
|
||||
model = _make_model_stub("open_router", "claude-sonnet-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_openai_provider_rejected(self):
|
||||
"""provider='openai' should be rejected for EXTENDED_THINKING."""
|
||||
model = _make_model_stub("openai", "gpt-4o")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_groq_provider_rejected(self):
|
||||
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
|
||||
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
|
||||
model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
model_name = model.value
|
||||
assert not model_name.startswith("claude")
|
||||
|
||||
def test_real_gpt4o_model_rejected(self):
|
||||
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
|
||||
model = LlmModel.GPT4O
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_real_claude_model_passes(self):
|
||||
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
|
||||
model = LlmModel.CLAUDE_4_6_SONNET
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration-style: exercise the validation branch via OrchestratorBlock.run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
|
||||
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
|
||||
inp = MagicMock()
|
||||
inp.execution_mode = execution_mode
|
||||
inp.model = model
|
||||
inp.prompt = "test"
|
||||
inp.sys_prompt = ""
|
||||
inp.conversation_history = []
|
||||
inp.last_tool_output = None
|
||||
inp.prompt_values = {}
|
||||
return inp
|
||||
|
||||
|
||||
async def _collect_run_outputs(block, input_data, **kwargs):
|
||||
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
|
||||
outputs = []
|
||||
async for item in block.run(input_data, **kwargs):
|
||||
outputs.append(item)
|
||||
return outputs
|
||||
|
||||
|
||||
class TestExtendedThinkingValidationRaisesInBlock:
|
||||
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_anthropic_provider_raises_valueerror(self):
|
||||
"""EXTENDED_THINKING + openai provider raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
input_data = _make_input_data(model=LlmModel.GPT4O)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="Anthropic-compatible"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_claude_model_with_anthropic_provider_raises(self):
|
||||
"""A model with anthropic provider but non-claude name raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
input_data = _make_input_data(model=fake_model)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="only supports Claude models"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -211,6 +211,30 @@ class TestConvertRawResponseToDict:
|
||||
# A single dict is wrong — there are two distinct items
|
||||
pytest.fail("Expected a list of output items, got a single dict")
|
||||
|
||||
def test_responses_api_strips_status_from_function_call(self):
|
||||
"""Responses API function_call items have a 'status' field that OpenAI
|
||||
rejects when sent back as input ('Unknown parameter: input[N].status').
|
||||
It must be stripped before the item is stored in conversation history."""
|
||||
resp = _MockResponse(
|
||||
output=[_MockFunctionCall("my_tool", '{"x": 1}', call_id="call_xyz")]
|
||||
)
|
||||
result = _convert_raw_response_to_dict(resp)
|
||||
assert isinstance(result, list)
|
||||
for item in result:
|
||||
assert (
|
||||
"status" not in item
|
||||
), f"'status' must be stripped from Responses API items: {item}"
|
||||
|
||||
def test_responses_api_strips_status_from_message(self):
|
||||
"""Responses API message items also carry 'status'; it must be stripped."""
|
||||
resp = _MockResponse(output=[_MockOutputMessage("Hello")])
|
||||
result = _convert_raw_response_to_dict(resp)
|
||||
assert isinstance(result, list)
|
||||
for item in result:
|
||||
assert (
|
||||
"status" not in item
|
||||
), f"'status' must be stripped from Responses API items: {item}"
|
||||
|
||||
|
||||
# ───────────────────────────────────────────────────────────────────────────
|
||||
# _get_tool_requests (lines 61-86)
|
||||
@@ -932,6 +956,12 @@ async def test_agent_mode_conversation_valid_for_responses_api():
|
||||
ep.execution_stats_lock = threading.Lock()
|
||||
ns = MagicMock(error=None)
|
||||
ep.on_node_execution = AsyncMock(return_value=ns)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
# Must be AsyncMock because it is async and is awaited in
|
||||
# _execute_single_tool_with_manager — a plain MagicMock would return a
|
||||
# non-awaitable tuple and TypeError out, then be silently swallowed by
|
||||
# the orchestrator's catch-all.
|
||||
ep.charge_node_usage = AsyncMock(return_value=(0, 0))
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=tool_sigs
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user