Merge branch 'dev' of https://github.com/Significant-Gravitas/AutoGPT into feat/task-decomposition-copilot

This commit is contained in:
anvyle
2026-04-09 12:27:23 +02:00
146 changed files with 14775 additions and 563 deletions

View File

@@ -0,0 +1,709 @@
---
name: orchestrate
description: "Meta-agent supervisor that manages a fleet of Claude Code agents running in tmux windows. Auto-discovers spare worktrees, spawns agents, monitors state, kicks idle agents, approves safe confirmations, and recycles worktrees when done. TRIGGER when user asks to supervise agents, run parallel tasks, manage worktrees, check agent status, or orchestrate parallel work."
user-invocable: true
argument-hint: "any free text — e.g. 'start 3 agents on X Y Z', 'show status', 'add task: implement feature A', 'stop', 'how many are free?'"
metadata:
author: autogpt-team
version: "6.0.0"
---
# Orchestrate — Agent Fleet Supervisor
One tmux session, N windows — each window is one agent working in its own worktree. Speak naturally; Claude maps your intent to the right scripts.
## Scripts
```bash
SKILLS_DIR=$(git rev-parse --show-toplevel)/.claude/skills/orchestrate/scripts
STATE_FILE=~/.claude/orchestrator-state.json
```
| Script | Purpose |
|---|---|
| `find-spare.sh [REPO_ROOT]` | List free worktrees — one `PATH BRANCH` per line |
| `spawn-agent.sh SESSION PATH SPARE NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]` | Create window + checkout branch + launch claude + send task. **Stdout: `SESSION:WIN` only** |
| `recycle-agent.sh WINDOW PATH SPARE_BRANCH` | Kill window + restore spare branch |
| `run-loop.sh` | **Mechanical babysitter** — idle restart + dialog approval + marking `ORCHESTRATOR:DONE` agents `pending_evaluation` + supervisor health check + all-done notification |
| `verify-complete.sh WINDOW` | Verify PR is done: checkpoints ✓ + 0 unresolved threads + CI green + no fresh CHANGES_REQUESTED. Repo auto-derived from state file `.repo` or git remote. |
| `notify.sh MESSAGE` | Send notification via Discord webhook (env `DISCORD_WEBHOOK_URL` or state `.discord_webhook`), macOS notification center, and stdout |
| `capacity.sh [REPO_ROOT]` | Print available + in-use worktrees |
| `status.sh` | Print fleet status + live pane commands |
| `poll-cycle.sh` | One monitoring cycle — classifies panes, tracks checkpoints, returns JSON action array |
| `classify-pane.sh WINDOW` | Classify one pane state |
## Supervision model
```
Orchestrating Claude (this Claude session — IS the supervisor)
└── Reads pane output, checks CI, intervenes with targeted guidance
run-loop.sh (separate tmux window, every 30s)
└── Mechanical only: idle restart, dialog approval, flagging ORCHESTRATOR:DONE as pending_evaluation
```
**You (the orchestrating Claude)** are the supervisor. After spawning agents, stay in this conversation and actively monitor: poll each agent's pane every 2-3 minutes, check CI, nudge stalled agents, and verify completions. Do not spawn a separate supervisor Claude window — it loses context, is hard to observe, and compounds context compression problems.
**run-loop.sh** is the mechanical layer — zero tokens, handles things that need no judgment: restart crashed agents, press Enter on dialogs, and flag completed agents as `pending_evaluation` for your evaluation. Recycling happens only when explicitly requested, and only after `verify-complete.sh` passes.
## Checkpoint protocol
Agents output checkpoints as they complete each required step:
```
CHECKPOINT:<step-name>
```
Required steps are passed as args to `spawn-agent.sh` (e.g. `pr-address pr-test`). `verify-complete.sh` will not pass — and a window will not be recycled — until all required checkpoints are found in the pane output. If `verify-complete.sh` fails, the agent is re-briefed automatically.
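A minimal agent-side sketch, using the example step names from above:
```bash
# CHECKPOINT lines are grepped out of the pane by poll-cycle.sh;
# ORCHESTRATOR:DONE must sit on its own line for classify-pane.sh to see it.
echo "CHECKPOINT:pr-address"   # after all review threads are addressed
echo "CHECKPOINT:pr-test"      # after local tests pass
echo "ORCHESTRATOR:DONE"       # final completion marker, own line
```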
## Worktree lifecycle
```text
spare/N branch → spawn-agent.sh (--session-id UUID) → window + feat/branch + claude running
CHECKPOINT:<step> (as steps complete)
ORCHESTRATOR:DONE
verify-complete.sh: checkpoints ✓ + 0 threads + CI green + no fresh CHANGES_REQUESTED
state → "done", notify, window KEPT OPEN
user/orchestrator explicitly requests recycle
recycle-agent.sh → spare/N (free again)
```
**Windows are never auto-killed.** The worktree stays on its branch, the session stays alive. The agent is done working but the window, git state, and Claude session are all preserved until you choose to recycle.
**To resume a done or crashed session:**
```bash
# Resume by stored session ID (preferred — exact session, full context)
claude --resume SESSION_ID --permission-mode bypassPermissions
# Or resume most recent session in that worktree directory
cd /path/to/worktree && claude --continue --permission-mode bypassPermissions
```
**To manually recycle when ready:**
```bash
bash ~/.claude/orchestrator/scripts/recycle-agent.sh SESSION:WIN WORKTREE_PATH spare/N
# Then update state:
jq --arg w "SESSION:WIN" '.agents |= map(if .window == $w then .state = "recycled" else . end)' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
## State file (`~/.claude/orchestrator-state.json`)
Never committed to git. You maintain this file directly using `jq` + atomic writes (write to `.tmp`, then `mv`).
```json
{
"active": true,
"tmux_session": "autogpt1",
"idle_threshold_seconds": 300,
"loop_window": "autogpt1:5",
"repo": "Significant-Gravitas/AutoGPT",
"discord_webhook": "https://discord.com/api/webhooks/...",
"last_poll_at": 0,
"agents": [
{
"window": "autogpt1:3",
"worktree": "AutoGPT6",
"worktree_path": "/path/to/AutoGPT6",
"spare_branch": "spare/6",
"branch": "feat/my-feature",
"objective": "Implement X and open a PR",
"pr_number": "12345",
"session_id": "550e8400-e29b-41d4-a716-446655440000",
"steps": ["pr-address", "pr-test"],
"checkpoints": ["pr-address"],
"state": "running",
"last_output_hash": "",
"last_seen_at": 0,
"spawned_at": 0,
"idle_since": 0,
"revision_count": 0,
"last_rebriefed_at": 0
}
]
}
```
Top-level optional fields:
- `repo` — GitHub `owner/repo` for CI/thread checks. Auto-derived from git remote if omitted.
- `discord_webhook` — Discord webhook URL for completion notifications. Also reads `DISCORD_WEBHOOK_URL` env var.
Per-agent fields:
- `session_id` — UUID passed to `claude --session-id` at spawn; use with `claude --resume UUID` to restore exact session context after a crash or window close.
- `last_rebriefed_at` — Unix timestamp of last re-brief; enforces 5-min cooldown to prevent spam.
Agent states: `running` | `idle` | `stuck` | `waiting_approval` | `complete` | `pending_evaluation` | `done` | `escalated`
`done` means verified complete — window is still open, session still alive, worktree still on task branch. Not recycled yet.
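For a quick fleet summary grouped by state, a jq one-liner over this schema (a sketch):
```bash
jq -r '.agents | group_by(.state) | map("\(.[0].state): \(length)") | .[]' \
  ~/.claude/orchestrator-state.json
```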
## Serial /pr-test rule
`/pr-test` and `/pr-test --fix` run local Docker + integration tests that use shared ports, a shared database, and shared build caches. **Running two `/pr-test` jobs simultaneously will cause port conflicts and database corruption.**
**Rule: only one `/pr-test` runs at a time. The orchestrator serializes them.**
You (the orchestrating Claude) own the test queue:
1. Agents do `pr-review` and `pr-address` in parallel — that's safe (they only push code and reply to GitHub).
2. When a PR needs local testing, add it to your mental queue — don't give agents a `pr-test` step.
3. Run `/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER --fix` yourself, sequentially.
4. Feed results back to the relevant agent via `tmux send-keys`:
```bash
tmux send-keys -t SESSION:WIN "Local tests for PR #N: <paste failure output or 'all passed'>. Fix any failures and push, then output ORCHESTRATOR:DONE."
sleep 0.3
tmux send-keys -t SESSION:WIN Enter
```
5. Wait for CI to confirm green before marking the agent done.
If multiple PRs need testing at the same time, pick the one furthest along (fewest pending CI checks) and test it first. Only start the next test after the previous one completes.
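A sketch of that pick (assumes `gh pr checks` line output where pending checks contain the word `pending`; PR numbers are hypothetical):
```bash
# Rank queued PRs by pending CI checks, fewest first
for PR in 12636 12699; do
  PENDING=$(gh pr checks "$PR" --repo OWNER/REPO 2>/dev/null | grep -c pending || true)
  echo "$PENDING $PR"
done | sort -n | head -1   # e.g. "0 12636" → test 12636 first
```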
## Session restore (tested and confirmed)
Agent sessions are saved to disk. To restore a closed or crashed session:
```bash
# If session_id is in state (preferred):
NEW_WIN=$(tmux new-window -t SESSION -n WORKTREE_NAME -P -F '#{window_index}')
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --resume SESSION_ID --permission-mode bypassPermissions" Enter
# If no session_id (use --continue for most recent session in that directory):
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --continue --permission-mode bypassPermissions" Enter
```
`--continue` restores the full conversation history including all tool calls, file edits, and context. The agent resumes exactly where it left off. After restoring, update the window address in the state file:
```bash
jq --arg old "SESSION:OLD_WIN" --arg new "SESSION:NEW_WIN" \
'(.agents[] | select(.window == $old)).window = $new' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
## Intent → action mapping
Match the user's message to one of these intents:
| The user says something like… | What to do |
|---|---|
| "status", "what's running", "show agents" | Run `status.sh` + `capacity.sh`, show output |
| "how many free", "capacity", "available worktrees" | Run `capacity.sh`, show output |
| "start N agents on X, Y, Z" or "run these tasks: …" | See **Spawning agents** below |
| "add task: …", "add one more agent for …" | See **Adding an agent** below |
| "stop", "shut down", "pause the fleet" | See **Stopping** below |
| "poll", "check now", "run a cycle" | Run `poll-cycle.sh`, process actions |
| "recycle window X", "free up autogpt3" | Run `recycle-agent.sh` directly |
When the intent is ambiguous, show capacity first and ask what tasks to run.
## Spawning agents
### 1. Resolve tmux session
```bash
tmux list-sessions -F "#{session_name}: #{session_windows} windows" 2>/dev/null
```
Use an existing session. **Never create a tmux session from within Claude** — it becomes a child of Claude's process and dies when the session ends. If no session exists, tell the user to run `tmux new-session -d -s autogpt1` in their terminal first, then re-invoke `/orchestrate`.
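A defensive sketch of that resolution:
```bash
# Reuse the first existing session; never create one from inside Claude
SESSION=$(tmux list-sessions -F '#{session_name}' 2>/dev/null | head -1)
if [ -z "$SESSION" ]; then
  echo "No tmux session. Run in your terminal: tmux new-session -d -s autogpt1"
fi
```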
### 2. Show available capacity
```bash
bash $SKILLS_DIR/capacity.sh $(git rev-parse --show-toplevel)
```
### 3. Collect tasks from the user
For each task, gather:
- **objective** — what to do (e.g. "implement feature X and open a PR")
- **branch name** — e.g. `feat/my-feature` (derive from objective if not given; see the sketch after this list)
- **pr_number** — GitHub PR number if working on an existing PR (for verification)
- **steps** — required checkpoint names in order (e.g. `pr-address pr-test`) — derive from objective
Ask for `idle_threshold_seconds` only if the user mentions it (default: 300).
Never ask the user to specify a worktree — auto-assign from `find-spare.sh`.
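For the branch name, one mechanical derivation (a sketch — the slug rules are an assumption, not a project convention):
```bash
# "Implement retry logic for copilot" → feat/implement-retry-logic-for-copilot
NEW_BRANCH="feat/$(echo "$OBJECTIVE" | tr '[:upper:]' '[:lower:]' \
  | tr -cs 'a-z0-9' '-' | cut -c1-40 | sed 's/^-//; s/-*$//')"
```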
### 4. Spawn one agent per task
```bash
# Get ordered list of spare worktrees
SPARE_LIST=$(bash $SKILLS_DIR/find-spare.sh $(git rev-parse --show-toplevel))
# For each task, take the next spare line:
WORKTREE_PATH=$(echo "$SPARE_LINE" | awk '{print $1}')
SPARE_BRANCH=$(echo "$SPARE_LINE" | awk '{print $2}')
# With PR number and required steps:
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE" "$PR_NUMBER" "pr-address" "pr-test")
# Without PR (new work):
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE")
```
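Putting steps 3–4 together — a sketch of the per-task loop (`TASKS`/`BRANCHES` are hypothetical parallel arrays built from the user's request):
```bash
TASKS=("Implement X and open a PR" "Fix Y")
BRANCHES=("feat/x" "fix/y")
i=0
while IFS= read -r SPARE_LINE && [ "$i" -lt "${#TASKS[@]}" ]; do
  WORKTREE_PATH=$(echo "$SPARE_LINE" | awk '{print $1}')
  SPARE_BRANCH=$(echo "$SPARE_LINE" | awk '{print $2}')
  WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" \
    "${BRANCHES[$i]}" "${TASKS[$i]}")
  echo "Spawned '${TASKS[$i]}' in $WINDOW"
  i=$((i + 1))
done <<< "$SPARE_LIST"
```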
`spawn-agent.sh` writes the agent record itself (see the note below) — do not build or append one manually. If the state file doesn't exist yet, initialize it first:
```bash
# Derive repo from git remote (used by verify-complete.sh + supervisor)
REPO=$(git remote get-url origin 2>/dev/null | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
jq -n \
--arg session "$SESSION" \
--arg repo "$REPO" \
--argjson threshold 300 \
'{active:true, tmux_session:$session, idle_threshold_seconds:$threshold,
repo:$repo, loop_window:null, supervisor_window:null, last_poll_at:0, agents:[]}' \
> ~/.claude/orchestrator-state.json
```
Optionally add a Discord webhook for completion notifications:
```bash
jq --arg hook "$DISCORD_WEBHOOK_URL" '.discord_webhook = $hook' ~/.claude/orchestrator-state.json \
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
`spawn-agent.sh` writes the initial agent record (window, worktree_path, branch, objective, state, etc.) to the state file automatically — **do not append the record again after calling it.** The record already exists and `pr_number`/`steps` are patched in by the script itself.
### 5. Start the mechanical babysitter
```bash
LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter
jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
### 6. Begin supervising directly in this conversation
You are the supervisor. After spawning, immediately start your first poll loop (see **Supervisor duties** below) and continue every 2-3 minutes. Do NOT spawn a separate supervisor Claude window.
## Adding an agent
Find the next spare worktree, then spawn and append to state — same as steps 2–4 above but for a single task. If no spare worktrees are available, tell the user.
## Supervisor duties (YOUR job, every 2-3 min in this conversation)
You are the supervisor. Run this poll loop directly in your Claude session — not in a separate window.
### Poll loop mechanism
You are reactive — you only act when a tool completes or the user sends a message. To create a self-sustaining poll loop without user involvement:
1. Start each poll with `run_in_background: true` + a sleep before the work:
```bash
sleep 120 && tmux capture-pane -t autogpt1:0 -p -S -200 | tail -40
# + similar for each active window
```
2. When the background job notifies you, read the pane output and take action.
3. Immediately schedule the next background poll — this keeps the loop alive.
4. Stop scheduling when all agents are done/escalated.
**Never tell the user "I'll poll every 2-3 minutes"** — that does nothing without a trigger. Start the background job instead.
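A sketch of one such background poll:
```bash
sleep 120
# Window list comes from state, never from memory (see the rule below)
jq -r '.agents[] | select(.state | test("running|idle|stuck|waiting_approval|pending_evaluation")) | .window' \
  ~/.claude/orchestrator-state.json \
| while IFS= read -r WIN; do
    echo "=== $WIN ==="
    tmux capture-pane -t "$WIN" -p -S -200 | tail -40
  done
```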
### Each poll: what to check
```bash
# 1. Read state
cat ~/.claude/orchestrator-state.json | jq '.agents[] | {window, worktree, branch, state, pr_number, checkpoints}'
# 2. For each running/stuck/idle agent, capture pane
tmux capture-pane -t SESSION:WIN -p -S -200 | tail -60
```
For each agent, decide:
| What you see | Action |
|---|---|
| Spinner / tools running | Do nothing — agent is working |
| Idle prompt, no `ORCHESTRATOR:DONE` | Stalled — send specific nudge with objective from state |
| Stuck in error loop | Send targeted fix with exact error + solution |
| Waiting for input / question | Answer and unblock via `tmux send-keys` |
| CI red | `gh pr checks PR_NUMBER --repo REPO` → tell agent exactly what's failing |
| GitHub abuse rate limit error | Nudge: "Wait 60 seconds then continue posting replies with sleep 3 between each" |
| Context compacted / agent lost | Send recovery: `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="WIN")'` + `gh pr view PR_NUMBER --json title,body` |
| `ORCHESTRATOR:DONE` in output | Query GraphQL for actual unresolved count. If >0, re-brief. If 0, run `verify-complete.sh` |
**Poll all windows from state, not from memory.** Before each poll, run:
```bash
jq -r '.agents[] | select(.state | test("running|idle|stuck|waiting_approval|pending_evaluation")) | .window' ~/.claude/orchestrator-state.json
```
and capture every window listed. If you manually added a window outside spawn-agent.sh, ensure it's in the state file first.
### RUNNING count includes waiting_approval agents
The `RUNNING` count from run-loop.sh includes agents in `waiting_approval` state (they match the regex `running|stuck|waiting_approval|idle`). This means a fleet that is only `waiting_approval` still shows RUNNING > 0 in the log — it does **not** mean agents are actively working.
When you see `RUNNING > 0` in the run-loop log but suspect agents are actually blocked, check state directly:
```bash
jq '.agents[] | {window, state, worktree}' ~/.claude/orchestrator-state.json
```
A count of `running=1 waiting=1` in the log actually means one agent is waiting for approval — the orchestrator should check and approve, not wait.
### State file staleness recovery
The state file is written by scripts but can drift from reality when windows are closed, sessions expire, or the orchestrator restarts across conversations.
**Signs of stale state:**
- `loop_window` points to a window that no longer exists in the tmux session
- An agent's `state` is `running` but tmux window is closed or shows a shell prompt (not claude)
- `last_seen_at` is hours old but state still says `running`
**Recovery steps:**
1. **Verify actual tmux windows:**
```bash
tmux list-windows -t SESSION -F '#{window_index}: #{window_name} (#{pane_current_command})'
```
2. **Cross-reference with state file:**
```bash
jq -r '.agents[] | "\(.window) \(.state) \(.worktree)"' ~/.claude/orchestrator-state.json
```
3. **Fix stale entries:**
```bash
# Agent window closed — mark idle so run-loop.sh will restart it
jq --arg w "SESSION:WIN" '(.agents[] | select(.window==$w)).state = "idle"' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
# loop_window gone — kill the stale reference, then restart run-loop.sh
jq '.loop_window = null' ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter
jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
4. **After any state repair, re-run `status.sh` to confirm coherence before resuming supervision.**
### Strict ORCHESTRATOR:DONE gate
`verify-complete.sh` handles the main checks automatically (checkpoints, threads, CI green, spawned_at, and CHANGES_REQUESTED). Run it:
```bash
SKILLS_DIR=~/.claude/orchestrator/scripts
bash $SKILLS_DIR/verify-complete.sh SESSION:WIN
```
**CHANGES_REQUESTED staleness rule**: a `CHANGES_REQUESTED` review only blocks if it was submitted *after* the latest commit. If the latest commit postdates the review, the review is considered stale (feedback already addressed) and does not block. This avoids false negatives when a bot reviewer hasn't re-reviewed after the agent's fixing commits.
If it passes → mark the agent `done` and notify the user; the window stays open until recycling is explicitly requested.
If it fails → re-brief the agent with the failure reason. Never manually mark state `done` to bypass this.
### Re-brief a stalled agent
**Before sending any nudge, verify the pane is at an idle prompt.** Sending text into a still-processing pane produces stuck `[Pasted text +N lines]` that the agent never sees.
Check:
```bash
tmux capture-pane -t SESSION:WIN -p 2>/dev/null | tail -5
```
If the last line shows a spinner (✳✽✢✶·), `Running…`, or no idle prompt — wait 10–15s and check again before sending.
```bash
OBJ=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .objective' ~/.claude/orchestrator-state.json)
PR=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .pr_number' ~/.claude/orchestrator-state.json)
tmux send-keys -t SESSION:WIN "You appear stalled. Your objective: $OBJ. Check: gh pr view $PR --json title,body,headRefName to reorient."
sleep 0.3
tmux send-keys -t SESSION:WIN Enter
```
If `image_path` is set on the agent record, include: "Re-read context at IMAGE_PATH with the Read tool."
## Self-recovery protocol (agents)
spawn-agent.sh automatically includes this instruction in every objective:
> If your context compacts and you lose track of what to do, run:
> `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="SESSION:WIN")'`
> and `gh pr view PR_NUMBER --json title,body,headRefName` to reorient.
> Output each completed step as `CHECKPOINT:<step-name>` on its own line.
## Passing images and screenshots to agents
`tmux send-keys` is text-only — you cannot paste a raw image into a pane. To give an agent visual context (screenshots, diagrams, mockups):
1. **Save the image to a temp file** with a stable path:
```bash
# If the user drags in a screenshot or you receive a file path:
IMAGE_PATH="/tmp/orchestrator-context-$(date +%s).png"
cp "$USER_PROVIDED_PATH" "$IMAGE_PATH"
```
2. **Reference the path in the objective string**:
```bash
OBJECTIVE="Implement the layout shown in /tmp/orchestrator-context-1234567890.png. Read that image first with the Read tool to understand the design."
```
3. The agent uses its `Read` tool to view the image at startup — Claude Code agents are multimodal and can read image files directly.
**Rule**: always use `/tmp/orchestrator-context-<timestamp>.png` as the naming convention so the supervisor knows what to look for if it needs to re-brief an agent with the same image.
---
## Orchestrator final evaluation (YOU decide, not the script)
`verify-complete.sh` is a gate — it blocks premature marking. But it cannot tell you if the work is actually good. That is YOUR job.
When run-loop marks an agent `pending_evaluation` and you're notified, do all of these before marking done:
### 1. Run /pr-test (required, serialized, use TodoWrite to queue)
`/pr-test` is the only reliable confirmation that the objective is actually met. Run it yourself, not the agent.
**When multiple PRs reach `pending_evaluation` at the same time, use TodoWrite to queue them:**
```
- [ ] /pr-test PR #12636 — fix copilot retry logic
- [ ] /pr-test PR #12699 — builder chat panel
```
Run one at a time. Check off as you go.
```
/pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/PR_NUMBER
```
**/pr-test can be lazy** — if it gives vague output, re-run with full context:
```
/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER
Context: This PR implements <objective from state file>. Key files: <list>.
Please verify: <specific behaviors to check>.
```
Only one `/pr-test` at a time — they share ports and DB.
### /pr-test result evaluation
**PARTIAL on any headline feature scenario is an immediate blocker.** Do not approve, do not mark done, do not let the agent output `ORCHESTRATOR:DONE`.
| `/pr-test` result | Action |
|---|---|
| All headline scenarios **PASS** | Proceed to evaluation step 2 |
| Any headline scenario **PARTIAL** | Re-brief the agent immediately — see below |
| Any headline scenario **FAIL** | Re-brief the agent immediately |
**What PARTIAL means**: the feature is only partly working. Example: the Apply button never appeared, or the AI returned no action blocks. The agent addressed part of the objective but not all of it.
**When any headline scenario is PARTIAL or FAIL:**
1. Do NOT mark the agent done or accept `ORCHESTRATOR:DONE`
2. Re-brief the agent with the specific scenario that failed and what was missing:
```bash
tmux send-keys -t SESSION:WIN "PARTIAL result on /pr-test — S5 (Apply button) never appeared. The AI must output JSON action blocks for the Apply button to render. Fix this before re-running /pr-test."
sleep 0.3
tmux send-keys -t SESSION:WIN Enter
```
3. Set state back to `running`:
```bash
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "running"' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
4. Wait for new `ORCHESTRATOR:DONE`, then re-run `/pr-test` from scratch
**Rule: only ALL-PASS qualifies for approval.** A mix of PASS + PARTIAL is a failure.
> **Why this matters**: PR #12699 was wrongly approved with S5 PARTIAL — the AI never output JSON action blocks so the Apply button never appeared. The fix was already in the agent's reach but slipped through because PARTIAL was not treated as blocking.
### 2. Do your own evaluation
1. **Read the PR diff and objective** — does the code actually implement what was asked? Is anything obviously missing or half-done?
2. **Read the resolved threads** — were comments addressed with real fixes, or just dismissed/resolved without changes?
3. **Check CI run names** — any suspicious retries that shouldn't have passed?
4. **Check the PR description** — title, summary, test plan complete?
### 3. Decide
- `/pr-test` all scenarios PASS + evaluation looks good → mark `done` in state, tell the user the PR is ready, ask if window should be closed
- `/pr-test` any scenario PARTIAL or FAIL → re-brief the agent with the specific failing scenario, set state back to `running` (see `/pr-test result evaluation` above)
- Evaluation finds gaps even with all PASS → re-brief the agent with specific gaps, set state back to `running`
**Never mark done based purely on script output.** You hold the full objective context; the script does not.
```bash
# Mark done after your positive evaluation:
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "done"' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
## When to stop the fleet
Stop the fleet (`active = false`) when **all** of the following are true:
| Check | How to verify |
|---|---|
| All agents are `done` or `escalated` | `jq '[.agents[] | select(.state | test("running\|stuck\|idle\|waiting_approval"))] | length' ~/.claude/orchestrator-state.json` == 0 |
| All PRs have 0 unresolved review threads | GraphQL `isResolved` check per PR |
| All PRs have green CI **on a run triggered after the agent's last push** | `gh run list --branch BRANCH --limit 1` timestamp > `spawned_at` in state |
| No fresh CHANGES_REQUESTED (after latest commit) | `verify-complete.sh` checks this — stale pre-commit reviews are ignored |
| No agents are `escalated` without human review | If any are escalated, surface to user first |
**Do NOT stop just because agents output `ORCHESTRATOR:DONE`.** That is a signal to verify, not a signal to stop.
**Do stop** if the user explicitly says "stop", "shut down", or "kill everything", even with agents still running.
```bash
# Graceful stop
jq '.active = false' ~/.claude/orchestrator-state.json > /tmp/orch.tmp \
&& mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
LOOP_WINDOW=$(jq -r '.loop_window // ""' ~/.claude/orchestrator-state.json)
[ -n "$LOOP_WINDOW" ] && tmux kill-window -t "$LOOP_WINDOW" 2>/dev/null || true
```
Stopping does **not** recycle running worktrees — agents may still be mid-task. Run `capacity.sh` to see what's still in progress.
## tmux send-keys pattern
**Always split long messages into text + Enter as two separate calls with a sleep between them.** If sent as one call (`"text" Enter`), Enter can fire before the full string is buffered into Claude's input — leaving the message stuck as `[Pasted text +N lines]` unsent.
```bash
# CORRECT — text then Enter separately
tmux send-keys -t "$WINDOW" "your long message here"
sleep 0.3
tmux send-keys -t "$WINDOW" Enter
# WRONG — Enter may fire before text is buffered
tmux send-keys -t "$WINDOW" "your long message here" Enter
```
Short single-character sends (`y`, `Down`, empty Enter for dialog approval) are safe to combine since they have no buffering lag.
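A tiny wrapper (hypothetical helper, not one of the shipped scripts) keeps the pattern consistent:
```bash
send_to_agent() {  # usage: send_to_agent SESSION:WIN "long message"
  tmux send-keys -t "$1" "$2"
  sleep 0.3
  tmux send-keys -t "$1" Enter
}
```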
---
## Protected worktrees
Some worktrees must **never** be used as spare worktrees for agent tasks because they host files critical to the orchestrator itself:
| Worktree | Protected branch | Why |
|---|---|---|
| `AutoGPT1` | `dx/orchestrate-skill` | Hosts the orchestrate skill scripts. `recycle-agent.sh` would check out `spare/1`, wiping `.claude/skills/` and breaking all subsequent `spawn-agent.sh` calls. |
**Rule**: when selecting spare worktrees via `find-spare.sh`, skip any worktree whose CURRENT branch matches a protected branch. If you accidentally spawn an agent in a protected worktree, do not let `recycle-agent.sh` run on it — manually restore the branch after the agent finishes.
When `dx/orchestrate-skill` is merged into `dev`, `AutoGPT1` becomes a normal spare again.
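A sketch of the skip when consuming `find-spare.sh` output (the path filter is an assumption — adapt it to where the protected worktree lives):
```bash
# Drop AutoGPT1 from the candidate list even if it sits on a spare/N branch
SPARE_LIST=$(bash $SKILLS_DIR/find-spare.sh "$REPO_ROOT" | grep -v '/AutoGPT1 ' || true)
```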
---
## Thread resolution integrity (critical)
**Agents MUST NOT resolve review threads via GraphQL unless a real code fix has been committed and pushed first.**
This is the most common failure mode: agents call `resolveReviewThread` to make unresolved counts drop without actually fixing anything. This produces a false "done" signal that gets past verify-complete.sh.
**The only valid resolution sequence:**
1. Read the thread and understand what it's asking
2. Make the actual code change
3. `git commit` and `git push`
4. Reply to the thread with the commit SHA (e.g. "Fixed in `abc1234`")
5. THEN call `resolveReviewThread`
**The supervisor must verify actual thread counts via GraphQL** — never trust an agent's claim of "0 unresolved." After any agent's ORCHESTRATOR:DONE, always run:
```bash
# Step 1: get total count
TOTAL=$(gh api graphql -f query='{ repository(owner: "OWNER", name: "REPO") { pullRequest(number: PR) { reviewThreads { totalCount } } } }' \
| jq '.data.repository.pullRequest.reviewThreads.totalCount')
echo "Total threads: $TOTAL"
# Step 2: paginate all pages and count unresolved
CURSOR=""; UNRESOLVED=0
while true; do
AFTER=${CURSOR:+", after: \"$CURSOR\""}
PAGE=$(gh api graphql -f query="{ repository(owner: \"OWNER\", name: \"REPO\") { pullRequest(number: PR) { reviewThreads(first: 100${AFTER}) { pageInfo { hasNextPage endCursor } nodes { isResolved } } } } }")
UNRESOLVED=$(( UNRESOLVED + $(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved==false)] | length') ))
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
[ "$HAS_NEXT" = "false" ] && break
done
echo "Unresolved: $UNRESOLVED"
```
If unresolved > 0, the agent is NOT done — re-brief with the actual count and the rule.
**Include this in every agent objective:**
> IMPORTANT: Do NOT resolve any review thread via GraphQL unless the code fix is committed and pushed first. Fix the code → commit → push → reply with SHA → then resolve. Never resolve without a real commit. "Accepted" or "Acknowledged" replies are NOT resolutions — only real commits qualify.
### Detecting fake resolutions
When an agent claims "0 unresolved threads", query GitHub GraphQL yourself and also inspect how each thread was resolved. A resolved thread whose last comment is `"Acknowledged"`, `"Same as above"`, `"Accepted trade-off"`, or `"Deferred"` — with no commit SHA — is a fake resolution.
To spot these, paginate all pages and collect resolved threads with missing SHA links:
```bash
# Paginate all pages — first:100 misses threads beyond page 1 on large PRs
CURSOR=""; FAKE_RESOLUTIONS="[]"
while true; do
AFTER=${CURSOR:+", after: \"$CURSOR\""}
PAGE=$(gh api graphql -f query="
{
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
pullRequest(number: PR_NUMBER) {
reviewThreads(first: 100${AFTER}) {
pageInfo { hasNextPage endCursor }
nodes {
isResolved
comments(last: 1) {
nodes { body author { login } }
}
}
}
}
}
}")
PAGE_FAKES=$(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[]
| select(.isResolved == true)
| {body: .comments.nodes[0].body[:120], author: .comments.nodes[0].author.login}
| select(.body | test("Fixed in|Removed in|Addressed in") | not)]')
FAKE_RESOLUTIONS=$(echo "$FAKE_RESOLUTIONS $PAGE_FAKES" | jq -s 'add')
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
[ "$HAS_NEXT" = "false" ] && break
done
echo "$FAKE_RESOLUTIONS"
```
Any resolved thread whose last comment does NOT contain `"Fixed in"`, `"Removed in"`, or `"Addressed in"` (with a commit link) should be investigated — either the agent falsely resolved it, or it was a genuine false positive that needs explanation.
## GitHub abuse rate limits
Two distinct rate limits exist with different recovery times:
| Error | HTTP status | Cause | Recovery |
|---|---|---|---|
| `{"code":"abuse"}` in body | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **23 minutes**. 60s is often not enough. |
| `API rate limit exceeded` | 429 | Primary rate limit — too many read calls per hour | Wait until `X-RateLimit-Reset` timestamp |
**Prevention:** Agents must add `sleep 3` between individual thread reply API calls. For >20 unresolved threads, increase to `sleep 5`.
If you see a 403 `abuse` error from an agent's pane:
1. Nudge the agent: `"You hit a GitHub secondary rate limit (403). Stop all API writes. Wait 2 minutes, then resume with sleep 3 between each thread reply."`
2. Do NOT nudge again during the 2-minute wait — a second nudge restarts the clock.
Add this to agent briefings when there are >20 unresolved threads:
> Post replies with `sleep 3` between each reply. If you hit a 403 abuse error, wait 2 minutes (not 60s — secondary limits take longer to clear) then continue.
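For the primary (429) limit, the reset time can be read instead of guessed (a sketch — `gh api rate_limit` is a standard endpoint):
```bash
RESET=$(gh api rate_limit --jq '.resources.core.reset')
NOW=$(date +%s)
sleep $(( RESET > NOW ? RESET - NOW + 5 : 5 ))   # small margin past the reset
```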
## Key rules
1. **Scripts do all the heavy lifting** — don't reimplement their logic inline in this file
2. **Never ask the user to pick a worktree** — auto-assign from `find-spare.sh` output
3. **Never restart a running agent** — only restart on `idle` kicks (foreground is a shell)
4. **Auto-dismiss settings dialogs** — if "Enter to confirm" appears, send Down+Enter
5. **Always `--permission-mode bypassPermissions`** on every spawn
6. **Escalate after 3 kicks** — mark `escalated`, surface to user
7. **Atomic state writes** — always write to `.tmp` then `mv`
8. **Never approve destructive commands** outside the worktree scope — when in doubt, escalate
9. **Never recycle without verification** — `verify-complete.sh` must pass before recycling
10. **No TASK.md files** — commit risk; use state file + `gh pr view` for agent context persistence
11. **Re-brief stalled agents** — read objective from state file + `gh pr view`, send via tmux
12. **ORCHESTRATOR:DONE is a signal to verify, not to accept** — always run `verify-complete.sh` and check CI run timestamp before recycling
13. **Protected worktrees** — never use the worktree hosting the skill scripts as a spare
14. **Images via file path** — save screenshots to `/tmp/orchestrator-context-<ts>.png`, pass path in objective; agents read with the `Read` tool
15. **Split send-keys** — always separate text and Enter with `sleep 0.3` between calls for long strings
16. **Poll ALL windows from state file** — never hardcode window count. Derive active windows dynamically: `jq -r '.agents[] | select(.state | test("running|idle|stuck")) | .window' ~/.claude/orchestrator-state.json`. If you added a window mid-session outside spawn-agent.sh, add it to the state file immediately.
17. **Orchestrator handles its own approvals** — when spawning a subagent to make edits (SKILL.md, scripts, config), review the diff yourself and approve/reject without surfacing it to the user. The user should never have to open a file to check the orchestrator's work. Use the Agent tool with `subagent_type: general-purpose` for drafting, then verify the result yourself before considering the task done.
18. **Update state file on re-task** — whenever an agent is re-tasked mid-session (objective changes, new PR assigned), update the state file record immediately so objectives stay accurate for re-briefing after compaction.
19. **No GraphQL resolveReviewThread without a commit** — see Thread resolution integrity above. This is rule #1 for pr-address work.
20. **Verify thread counts yourself** — after any agent claims "0 unresolved threads", query GitHub GraphQL directly before accepting. Never trust the agent's self-report.

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# capacity.sh — show fleet capacity: available spare worktrees + in-use agents
#
# Usage: capacity.sh [REPO_ROOT]
# REPO_ROOT defaults to the root worktree of the current git repo.
#
# Reads: ~/.claude/orchestrator-state.json (skipped if missing or corrupt)
set -euo pipefail
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
echo "=== Available (spare) worktrees ==="
if [ -n "$REPO_ROOT" ]; then
SPARE=$("$SCRIPTS_DIR/find-spare.sh" "$REPO_ROOT" 2>/dev/null || echo "")
else
SPARE=$("$SCRIPTS_DIR/find-spare.sh" 2>/dev/null || echo "")
fi
if [ -z "$SPARE" ]; then
echo " (none)"
else
while IFS= read -r line; do
[ -z "$line" ] && continue
echo "$line"
done <<< "$SPARE"
fi
echo ""
echo "=== In-use worktrees ==="
if [ -f "$STATE_FILE" ] && jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
IN_USE=$(jq -r '.agents[] | select(.state != "done") | " [\(.state)] \(.worktree_path) → \(.branch)"' \
"$STATE_FILE" 2>/dev/null || echo "")
if [ -n "$IN_USE" ]; then
echo "$IN_USE"
else
echo " (none)"
fi
else
echo " (no active state file)"
fi

View File

@@ -0,0 +1,85 @@
#!/usr/bin/env bash
# classify-pane.sh — Classify the current state of a tmux pane
#
# Usage: classify-pane.sh <tmux-target>
# tmux-target: e.g. "work:0", "work:1.0"
#
# Output (stdout): JSON object:
# { "state": "running|idle|waiting_approval|complete", "reason": "...", "pane_cmd": "..." }
#
# Exit codes: 0=ok, 1=error (invalid target or tmux window not found)
set -euo pipefail
TARGET="${1:-}"
if [ -z "$TARGET" ]; then
echo '{"state":"error","reason":"no target provided","pane_cmd":""}'
exit 1
fi
# Validate tmux target format: session:window or session:window.pane
if ! [[ "$TARGET" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
echo '{"state":"error","reason":"invalid tmux target format","pane_cmd":""}'
exit 1
fi
# Check session exists (use %%:* to extract session name from session:window)
if ! tmux list-windows -t "${TARGET%%:*}" >/dev/null 2>&1; then
echo '{"state":"error","reason":"tmux target not found","pane_cmd":""}'
exit 1
fi
# Get the current foreground command in the pane
PANE_CMD=$(tmux display-message -t "$TARGET" -p '#{pane_current_command}' 2>/dev/null || echo "unknown")
# Capture and strip ANSI codes (use perl for cross-platform compatibility — BSD sed lacks \x1b support)
RAW=$(tmux capture-pane -t "$TARGET" -p -S -50 2>/dev/null || echo "")
CLEAN=$(echo "$RAW" | perl -pe 's/\x1b\[[0-9;]*[a-zA-Z]//g; s/\x1b\(B//g; s/\x1b\[\?[0-9]*[hl]//g; s/\r//g' \
| grep -v '^[[:space:]]*$' || true)
# --- Check: explicit completion marker ---
# Must be on its own line (not buried in the objective text sent at spawn time).
if echo "$CLEAN" | grep -qE "^[[:space:]]*ORCHESTRATOR:DONE[[:space:]]*$"; then
jq -n --arg cmd "$PANE_CMD" '{"state":"complete","reason":"ORCHESTRATOR:DONE marker found","pane_cmd":$cmd}'
exit 0
fi
# --- Check: Claude Code approval prompt patterns ---
LAST_40=$(echo "$CLEAN" | tail -40)
APPROVAL_PATTERNS=(
"Do you want to proceed"
"Do you want to make this"
"\\[y/n\\]"
"\\[Y/n\\]"
"\\[n/Y\\]"
"Proceed\\?"
"Allow this command"
"Run bash command"
"Allow bash"
"Would you like"
"Press enter to continue"
"Esc to cancel"
)
for pattern in "${APPROVAL_PATTERNS[@]}"; do
if echo "$LAST_40" | grep -qiE "$pattern"; then
jq -n --arg pattern "$pattern" --arg cmd "$PANE_CMD" \
'{"state":"waiting_approval","reason":"approval pattern: \($pattern)","pane_cmd":$cmd}'
exit 0
fi
done
# --- Check: shell prompt (claude has exited) ---
# If the foreground process is a shell (not claude/node), the agent has exited
case "$PANE_CMD" in
zsh|bash|fish|sh|dash|tcsh|ksh)
jq -n --arg cmd "$PANE_CMD" \
'{"state":"idle","reason":"agent exited — shell prompt active","pane_cmd":$cmd}'
exit 0
;;
esac
# Agent is still running (claude/node/python is the foreground process)
jq -n --arg cmd "$PANE_CMD" \
'{"state":"running","reason":"foreground process: \($cmd)","pane_cmd":$cmd}'
exit 0

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env bash
# find-spare.sh — list worktrees on spare/N branches (free to use)
#
# Usage: find-spare.sh [REPO_ROOT]
# REPO_ROOT defaults to the root worktree containing the current git repo.
#
# Output (stdout): one line per available worktree: "PATH BRANCH"
# e.g.: /Users/me/Code/AutoGPT3 spare/3
set -euo pipefail
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
if [ -z "$REPO_ROOT" ]; then
echo "Error: not inside a git repo and no REPO_ROOT provided" >&2
exit 1
fi
git -C "$REPO_ROOT" worktree list --porcelain \
| awk '
/^worktree / { path = substr($0, 10) }
/^branch / { branch = substr($0, 8); print path " " branch }
' \
| { grep -E " refs/heads/spare/[0-9]+$" || true; } \
| sed 's|refs/heads/||'

View File

@@ -0,0 +1,40 @@
#!/usr/bin/env bash
# notify.sh — send a fleet notification message
#
# Delivery order (first available wins):
# 1. Discord webhook — DISCORD_WEBHOOK_URL env var OR state file .discord_webhook
# 2. macOS notification center — osascript (silent fail if unavailable)
# 3. Stdout only
#
# Usage: notify.sh MESSAGE
# Exit: always 0 (notification failure must not abort the caller)
MESSAGE="${1:-}"
[ -z "$MESSAGE" ] && exit 0
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
# --- Resolve Discord webhook ---
WEBHOOK="${DISCORD_WEBHOOK_URL:-}"
if [ -z "$WEBHOOK" ] && [ -f "$STATE_FILE" ]; then
WEBHOOK=$(jq -r '.discord_webhook // ""' "$STATE_FILE" 2>/dev/null || echo "")
fi
# --- Discord delivery ---
if [ -n "$WEBHOOK" ]; then
PAYLOAD=$(jq -n --arg msg "$MESSAGE" '{"content": $msg}')
curl -s -X POST "$WEBHOOK" \
-H "Content-Type: application/json" \
-d "$PAYLOAD" > /dev/null 2>&1 || true
fi
# --- macOS notification center (silent if not macOS or osascript missing) ---
if command -v osascript >/dev/null 2>&1; then
# Escape backslashes and double quotes — the message is embedded in a
# double-quoted AppleScript string literal below
SAFE_MSG=$(printf '%s' "$MESSAGE" | sed 's/\\/\\\\/g; s/"/\\"/g')
osascript -e "display notification \"${SAFE_MSG}\" with title \"Orchestrator\"" 2>/dev/null || true
fi
# Always print to stdout so run-loop.sh logs it
echo "$MESSAGE"
exit 0

View File

@@ -0,0 +1,257 @@
#!/usr/bin/env bash
# poll-cycle.sh — Single orchestrator poll cycle
#
# Reads ~/.claude/orchestrator-state.json, classifies each agent, updates state,
# and outputs a JSON array of actions for Claude to take.
#
# Usage: poll-cycle.sh
# Output (stdout): JSON array of action objects
# [{ "window": "work:0", "action": "kick|approve|none", "state": "...",
# "worktree": "...", "objective": "...", "reason": "..." }]
#
# The state file is updated in-place (atomic write via .tmp).
set -euo pipefail
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
CLASSIFY="$SCRIPTS_DIR/classify-pane.sh"
# Cross-platform md5: always outputs just the hex digest
md5_hash() {
if command -v md5sum &>/dev/null; then
md5sum | awk '{print $1}'
else
md5 | awk '{print $NF}'
fi
}
# Clean up temp file on any exit (avoids stale .tmp if jq write fails)
trap 'rm -f "${STATE_FILE}.tmp"' EXIT
# Ensure state file exists
if [ ! -f "$STATE_FILE" ]; then
echo '{"active":false,"agents":[]}' > "$STATE_FILE"
fi
# Validate JSON upfront before any jq reads that run under set -e.
# A truncated/corrupt file (e.g. from a SIGKILL mid-write) would otherwise
# abort the script at the ACTIVE read below without emitting any JSON output.
if ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
echo "State file parse error — check $STATE_FILE" >&2
echo "[]"
exit 0
fi
ACTIVE=$(jq -r '.active // false' "$STATE_FILE")
if [ "$ACTIVE" != "true" ]; then
echo "[]"
exit 0
fi
NOW=$(date +%s)
IDLE_THRESHOLD=$(jq -r '.idle_threshold_seconds // 300' "$STATE_FILE")
ACTIONS="[]"
UPDATED_AGENTS="[]"
# Read agents as newline-delimited JSON objects.
# jq exits non-zero when .agents[] has no matches on an empty array, which is valid —
# so we suppress that exit code and separately validate the file is well-formed JSON.
if ! AGENTS_JSON=$(jq -e -c '.agents // empty | .[]' "$STATE_FILE" 2>/dev/null); then
if ! jq -e '.' "$STATE_FILE" > /dev/null 2>&1; then
echo "State file parse error — check $STATE_FILE" >&2
fi
echo "[]"
exit 0
fi
if [ -z "$AGENTS_JSON" ]; then
echo "[]"
exit 0
fi
while IFS= read -r agent; do
[ -z "$agent" ] && continue
# Use // "" defaults so a single malformed field doesn't abort the whole cycle
WINDOW=$(echo "$agent" | jq -r '.window // ""')
WORKTREE=$(echo "$agent" | jq -r '.worktree // ""')
OBJECTIVE=$(echo "$agent"| jq -r '.objective // ""')
STATE=$(echo "$agent" | jq -r '.state // "running"')
LAST_HASH=$(echo "$agent"| jq -r '.last_output_hash // ""')
IDLE_SINCE=$(echo "$agent"| jq -r '.idle_since // 0')
REVISION_COUNT=$(echo "$agent"| jq -r '.revision_count // 0')
# Validate window format to prevent tmux target injection.
# Allow session:window (numeric or named) and session:window.pane
if ! [[ "$WINDOW" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
echo "Skipping agent with invalid window value: $WINDOW" >&2
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
continue
fi
# Pass-through terminal-state agents
if [[ "$STATE" == "done" || "$STATE" == "escalated" || "$STATE" == "complete" || "$STATE" == "pending_evaluation" ]]; then
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
continue
fi
# Classify pane.
# classify-pane.sh always emits JSON before exit (even on error), so using
# "|| echo '...'" would concatenate two JSON objects when it exits non-zero.
# Use "|| true" inside the substitution so set -euo pipefail does not abort
# the poll cycle when classify exits with a non-zero status code.
CLASSIFICATION=$("$CLASSIFY" "$WINDOW" 2>/dev/null || true)
[ -z "$CLASSIFICATION" ] && CLASSIFICATION='{"state":"error","reason":"classify failed","pane_cmd":"unknown"}'
PANE_STATE=$(echo "$CLASSIFICATION" | jq -r '.state')
PANE_REASON=$(echo "$CLASSIFICATION" | jq -r '.reason')
# Capture full pane output once — used for hash (stuck detection) and checkpoint parsing.
# Use -S -500 to get the last ~500 lines of scrollback so checkpoints aren't missed.
RAW=$(tmux capture-pane -t "$WINDOW" -p -S -500 2>/dev/null || echo "")
# --- Checkpoint tracking ---
# Parse any "CHECKPOINT:<step>" lines the agent has output and merge into state file.
# The agent writes these as it completes each required step so verify-complete.sh can gate recycling.
EXISTING_CPS=$(echo "$agent" | jq -c '.checkpoints // []')
NEW_CHECKPOINTS_JSON="$EXISTING_CPS"
if [ -n "$RAW" ]; then
FOUND_CPS=$(echo "$RAW" \
| grep -oE "CHECKPOINT:[a-zA-Z0-9_-]+" \
| sed 's/CHECKPOINT://' \
| sort -u \
| jq -R . | jq -s . 2>/dev/null || echo "[]")
NEW_CHECKPOINTS_JSON=$(jq -n \
--argjson existing "$EXISTING_CPS" \
--argjson found "$FOUND_CPS" \
'($existing + $found) | unique' 2>/dev/null || echo "$EXISTING_CPS")
fi
# Compute content hash for stuck-detection (only for running agents)
CURRENT_HASH=""
if [[ "$PANE_STATE" == "running" ]] && [ -n "$RAW" ]; then
CURRENT_HASH=$(echo "$RAW" | tail -20 | md5_hash)
fi
NEW_STATE="$STATE"
NEW_IDLE_SINCE="$IDLE_SINCE"
NEW_REVISION_COUNT="$REVISION_COUNT"
ACTION="none"
REASON="$PANE_REASON"
case "$PANE_STATE" in
complete)
# Agent output ORCHESTRATOR:DONE — mark pending_evaluation so orchestrator handles it.
# run-loop does NOT verify or notify; orchestrator's background poll picks this up.
NEW_STATE="pending_evaluation"
ACTION="complete" # run-loop logs it but takes no action
;;
waiting_approval)
NEW_STATE="waiting_approval"
ACTION="approve"
;;
idle)
# Agent process has exited — needs restart
NEW_STATE="idle"
ACTION="kick"
REASON="agent exited (shell is foreground)"
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
NEW_IDLE_SINCE=$NOW
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
NEW_STATE="escalated"
ACTION="none"
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
fi
;;
running)
# Clear idle_since only when transitioning from idle (agent was kicked and
# restarted). Do NOT reset for stuck — idle_since must persist across polls
# so STUCK_DURATION can accumulate and trigger escalation.
# Also update the local IDLE_SINCE so the hash-stability check below uses
# the reset value on this same poll, not the stale kick timestamp.
if [[ "$STATE" == "idle" ]]; then
NEW_IDLE_SINCE=0
IDLE_SINCE=0
fi
# Check if hash has been stable (agent may be stuck mid-task)
if [ -n "$CURRENT_HASH" ] && [ "$CURRENT_HASH" = "$LAST_HASH" ] && [ "$LAST_HASH" != "" ]; then
if [ "$IDLE_SINCE" = "0" ] || [ "$IDLE_SINCE" = "null" ]; then
NEW_IDLE_SINCE=$NOW
else
STUCK_DURATION=$(( NOW - IDLE_SINCE ))
if [ "$STUCK_DURATION" -gt "$IDLE_THRESHOLD" ]; then
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
NEW_IDLE_SINCE=$NOW
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
NEW_STATE="escalated"
ACTION="none"
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
else
NEW_STATE="stuck"
ACTION="kick"
REASON="output unchanged for ${STUCK_DURATION}s (threshold: ${IDLE_THRESHOLD}s)"
fi
fi
fi
else
# Only reset the idle timer when we have a valid hash comparison (pane
# capture succeeded). If CURRENT_HASH is empty (tmux capture-pane failed),
# preserve existing timers so stuck detection is not inadvertently reset.
if [ -n "$CURRENT_HASH" ]; then
NEW_STATE="running"
NEW_IDLE_SINCE=0
fi
fi
;;
error)
REASON="classify error: $PANE_REASON"
;;
esac
# Build updated agent record (ensure idle_since and revision_count are numeric)
# Use || true on each jq call so a malformed field skips this agent rather than
# aborting the entire poll cycle under set -e.
UPDATED_AGENT=$(echo "$agent" | jq \
--arg state "$NEW_STATE" \
--arg hash "$CURRENT_HASH" \
--argjson now "$NOW" \
--arg idle_since "$NEW_IDLE_SINCE" \
--arg revision_count "$NEW_REVISION_COUNT" \
--argjson checkpoints "$NEW_CHECKPOINTS_JSON" \
'.state = $state
| .last_output_hash = (if $hash == "" then .last_output_hash else $hash end)
| .last_seen_at = $now
| .idle_since = ($idle_since | tonumber)
| .revision_count = ($revision_count | tonumber)
| .checkpoints = $checkpoints' 2>/dev/null) || {
echo "Warning: failed to build updated agent for window $WINDOW — keeping original" >&2
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
continue
}
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$UPDATED_AGENT" '. + [$a]')
# Add action if needed
if [ "$ACTION" != "none" ]; then
ACTION_OBJ=$(jq -n \
--arg window "$WINDOW" \
--arg action "$ACTION" \
--arg state "$NEW_STATE" \
--arg worktree "$WORKTREE" \
--arg objective "$OBJECTIVE" \
--arg reason "$REASON" \
'{window:$window, action:$action, state:$state, worktree:$worktree, objective:$objective, reason:$reason}')
ACTIONS=$(echo "$ACTIONS" | jq --argjson a "$ACTION_OBJ" '. + [$a]')
fi
done <<< "$AGENTS_JSON"
# Atomic state file update
jq --argjson agents "$UPDATED_AGENTS" \
--argjson now "$NOW" \
'.agents = $agents | .last_poll_at = $now' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
echo "$ACTIONS"

View File

@@ -0,0 +1,32 @@
#!/usr/bin/env bash
# recycle-agent.sh — kill a tmux window and restore the worktree to its spare branch
#
# Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH
# WINDOW — tmux target, e.g. autogpt1:3
# WORKTREE_PATH — absolute path to the git worktree
# SPARE_BRANCH — branch to restore, e.g. spare/6
#
# Stdout: one status line
set -euo pipefail
if [ $# -lt 3 ]; then
echo "Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH" >&2
exit 1
fi
WINDOW="$1"
WORKTREE_PATH="$2"
SPARE_BRANCH="$3"
# Kill the tmux window (ignore error — may already be gone)
tmux kill-window -t "$WINDOW" 2>/dev/null || true
# Restore to spare branch: abort any in-progress operation, then clean
git -C "$WORKTREE_PATH" rebase --abort 2>/dev/null || true
git -C "$WORKTREE_PATH" merge --abort 2>/dev/null || true
git -C "$WORKTREE_PATH" reset --hard HEAD 2>/dev/null
git -C "$WORKTREE_PATH" clean -fd 2>/dev/null
git -C "$WORKTREE_PATH" checkout "$SPARE_BRANCH"
echo "Recycled: $(basename "$WORKTREE_PATH")$SPARE_BRANCH (window $WINDOW closed)"

View File

@@ -0,0 +1,215 @@
#!/usr/bin/env bash
# run-loop.sh — Mechanical babysitter for the agent fleet (runs in its own tmux window)
#
# Handles ONLY two things that need no intelligence:
# idle → restart claude using --resume SESSION_ID (or --continue) to restore context
# approve → auto-approve safe dialogs, press Enter on numbered-option dialogs
#
# Everything else — ORCHESTRATOR:DONE, verification, /pr-test, final evaluation,
# marking done, deciding to close windows — is the orchestrating Claude's job.
# poll-cycle.sh sets state to pending_evaluation when ORCHESTRATOR:DONE is detected;
# the orchestrator's background poll loop handles it from there.
#
# Usage: run-loop.sh
# Env: POLL_INTERVAL (default: 30), ORCHESTRATOR_STATE_FILE
set -euo pipefail
# Copy scripts to a stable location outside the repo so they survive branch
# checkouts (e.g. recycle-agent.sh switching spare/N back into this worktree
# would wipe .claude/skills/orchestrate/scripts if the skill only exists on the
# current branch).
_ORIGIN_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
STABLE_SCRIPTS_DIR="$HOME/.claude/orchestrator/scripts"
mkdir -p "$STABLE_SCRIPTS_DIR"
cp "$_ORIGIN_DIR"/*.sh "$STABLE_SCRIPTS_DIR/"
chmod +x "$STABLE_SCRIPTS_DIR"/*.sh
SCRIPTS_DIR="$STABLE_SCRIPTS_DIR"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
# Adaptive polling: starts at base interval, backs off up to POLL_IDLE_MAX when
# no agents need attention, resets on any activity or waiting_approval state.
POLL_INTERVAL="${POLL_INTERVAL:-30}"
POLL_IDLE_MAX=${POLL_IDLE_MAX:-300}
POLL_CURRENT=$POLL_INTERVAL
# ---------------------------------------------------------------------------
# update_state WINDOW FIELD VALUE
# ---------------------------------------------------------------------------
update_state() {
local window="$1" field="$2" value="$3"
jq --arg w "$window" --arg f "$field" --arg v "$value" \
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
}
update_state_int() {
local window="$1" field="$2" value="$3"
jq --arg w "$window" --arg f "$field" --argjson v "$value" \
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
}
agent_field() {
jq -r --arg w "$1" --arg f "$2" \
'.agents[] | select(.window == $w) | .[$f] // ""' \
"$STATE_FILE" 2>/dev/null
}
# ---------------------------------------------------------------------------
# wait_for_prompt WINDOW — wait up to 60s for Claude's prompt
# ---------------------------------------------------------------------------
wait_for_prompt() {
local window="$1"
for i in $(seq 1 60); do
local cmd pane
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
if echo "$pane" | grep -q "Enter to confirm"; then
tmux send-keys -t "$window" Down Enter; sleep 2; continue
fi
[[ "$cmd" == "node" ]] && echo "$pane" | grep -q "" && return 0
sleep 1
done
return 1 # timed out
}
# ---------------------------------------------------------------------------
# wait_for_claude_idle WINDOW — wait up to 30s for Claude to reach idle prompt
# (no spinner or busy indicator visible in the last 3 lines of pane output)
# Returns 0 when idle, 1 on timeout.
# ---------------------------------------------------------------------------
wait_for_claude_idle() {
local window="$1"
local timeout="${2:-30}"
local elapsed=0
while (( elapsed < timeout )); do
local cmd pane pane_tail
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
pane_tail=$(echo "$pane" | tail -3)
# Check full pane (not just tail) — 'Enter to confirm' dialog can scroll above last 3 lines.
# Do NOT reset elapsed — resetting allows an infinite loop if the dialog never clears.
if echo "$pane" | grep -q "Enter to confirm"; then
tmux send-keys -t "$window" Down Enter
sleep 2; (( elapsed += 2 )); continue
fi
# Must be running under node (Claude is live)
if [[ "$cmd" == "node" ]]; then
# Idle: prompt visible AND no spinner/busy text in last 3 lines
if echo "$pane_tail" | grep -q "" && \
! echo "$pane_tail" | grep -qE '[✳✽✢✶·✻✼✿❋✤]|Running…|Compacting'; then
return 0
fi
fi
sleep 2
(( elapsed += 2 ))
done
return 1 # timed out
}
# ---------------------------------------------------------------------------
# handle_kick WINDOW STATE — only for idle (crashed) agents, not stuck
# ---------------------------------------------------------------------------
handle_kick() {
local window="$1" state="$2"
[[ "$state" != "idle" ]] && return # stuck agents handled by supervisor
local worktree_path session_id
worktree_path=$(agent_field "$window" "worktree_path")
session_id=$(agent_field "$window" "session_id")
echo "[$(date +%H:%M:%S)] KICK restart $window — agent exited, resuming session"
# Wait for the shell prompt before typing — avoids sending into a still-draining pane
wait_for_claude_idle "$window" 30 \
|| echo "[$(date +%H:%M:%S)] KICK WARNING $window — pane still busy before resume, sending anyway"
# Resume the exact session so the agent retains full context — no need to re-send objective
if [ -n "$session_id" ]; then
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --resume '${session_id}' --permission-mode bypassPermissions" Enter
else
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --continue --permission-mode bypassPermissions" Enter
fi
wait_for_prompt "$window" || echo "[$(date +%H:%M:%S)] KICK WARNING $window — timed out waiting for "
}
# ---------------------------------------------------------------------------
# handle_approve WINDOW — auto-approve dialogs that need no judgment
# ---------------------------------------------------------------------------
handle_approve() {
local window="$1"
local pane_tail
pane_tail=$(tmux capture-pane -t "$window" -p 2>/dev/null | tail -3 || echo "")
# Settings error dialog at startup
if echo "$pane_tail" | grep -q "Enter to confirm"; then
echo "[$(date +%H:%M:%S)] APPROVE dialog $window — settings error"
tmux send-keys -t "$window" Down Enter
return
fi
# Numbered-option dialog (e.g. "Do you want to make this edit?")
# is already on option 1 (Yes) — Enter confirms it
if echo "$pane_tail" | grep -qE "\s*1\." || echo "$pane_tail" | grep -q "Esc to cancel"; then
echo "[$(date +%H:%M:%S)] APPROVE edit $window"
tmux send-keys -t "$window" "" Enter
return
fi
# y/n prompt for safe operations
if echo "$pane_tail" | grep -qiE "(^git |^npm |^pnpm |^poetry |^pytest|^docker |^make |^cargo |^pip |^yarn |curl .*(localhost|127\.0\.0\.1))"; then
echo "[$(date +%H:%M:%S)] APPROVE safe $window"
tmux send-keys -t "$window" "y" Enter
return
fi
# Anything else — supervisor handles it, just log
echo "[$(date +%H:%M:%S)] APPROVE skip $window — unknown dialog, supervisor will handle"
}
# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------
echo "[$(date +%H:%M:%S)] run-loop started (mechanical only, poll ${POLL_INTERVAL}s→${POLL_IDLE_MAX}s adaptive)"
echo "[$(date +%H:%M:%S)] Supervisor: orchestrating Claude session (not a separate window)"
echo "---"
while true; do
if ! jq -e '.active == true' "$STATE_FILE" >/dev/null 2>&1; then
echo "[$(date +%H:%M:%S)] active=false — exiting."
exit 0
fi
ACTIONS=$("$SCRIPTS_DIR/poll-cycle.sh" 2>/dev/null || echo "[]")
KICKED=0; DONE=0
while IFS= read -r action; do
[ -z "$action" ] && continue
WINDOW=$(echo "$action" | jq -r '.window // ""')
ACTION=$(echo "$action" | jq -r '.action // ""')
STATE=$(echo "$action" | jq -r '.state // ""')
case "$ACTION" in
kick) handle_kick "$WINDOW" "$STATE" || true; KICKED=$(( KICKED + 1 )) ;;
approve) handle_approve "$WINDOW" || true ;;
complete) DONE=$(( DONE + 1 )) ;; # poll-cycle already set state=pending_evaluation; orchestrator handles
esac
done < <(echo "$ACTIONS" | jq -c '.[]' 2>/dev/null || true)
RUNNING=$(jq '[.agents[] | select(.state | test("running|stuck|waiting_approval|idle"))] | length' \
"$STATE_FILE" 2>/dev/null || echo 0)
# Adaptive backoff: reset to base on activity or waiting_approval agents; back off when truly idle
WAITING=$(jq '[.agents[] | select(.state == "waiting_approval")] | length' "$STATE_FILE" 2>/dev/null || echo 0)
if (( KICKED > 0 || DONE > 0 || WAITING > 0 )); then
POLL_CURRENT=$POLL_INTERVAL
else
POLL_CURRENT=$(( POLL_CURRENT + POLL_CURRENT / 2 + 1 ))
(( POLL_CURRENT > POLL_IDLE_MAX )) && POLL_CURRENT=$POLL_IDLE_MAX
fi
echo "[$(date +%H:%M:%S)] Poll — ${RUNNING} running ${KICKED} kicked ${DONE} recycled (next in ${POLL_CURRENT}s)"
sleep "$POLL_CURRENT"
done

View File

@@ -0,0 +1,129 @@
#!/usr/bin/env bash
# spawn-agent.sh — create tmux window, checkout branch, launch claude, send task
#
# Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]
# SESSION — tmux session name, e.g. autogpt1
# WORKTREE_PATH — absolute path to the git worktree
# SPARE_BRANCH — spare branch being replaced, e.g. spare/6 (saved for recycle)
# NEW_BRANCH — task branch to create, e.g. feat/my-feature
# OBJECTIVE — task description sent to the agent
# PR_NUMBER — (optional) GitHub PR number for completion verification
# STEPS... — (optional) required checkpoint names, e.g. pr-address pr-test
#
# Stdout: SESSION:WINDOW_INDEX (nothing else — callers rely on this)
# Exit non-zero on failure.
set -euo pipefail
if [ $# -lt 5 ]; then
echo "Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]" >&2
exit 1
fi
SESSION="$1"
WORKTREE_PATH="$2"
SPARE_BRANCH="$3"
NEW_BRANCH="$4"
OBJECTIVE="$5"
PR_NUMBER="${6:-}"
STEPS=("${@:7}")
WORKTREE_NAME=$(basename "$WORKTREE_PATH")
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
# Generate a stable session ID so this agent's Claude session can always be resumed:
# claude --resume $SESSION_ID --permission-mode bypassPermissions
SESSION_ID=$(uuidgen 2>/dev/null || python3 -c "import uuid; print(uuid.uuid4())")
# Create (or switch to) the task branch
git -C "$WORKTREE_PATH" checkout -b "$NEW_BRANCH" 2>/dev/null \
|| git -C "$WORKTREE_PATH" checkout "$NEW_BRANCH"
# Open a new named tmux window; capture its numeric index
WIN_IDX=$(tmux new-window -t "$SESSION" -n "$WORKTREE_NAME" -P -F '#{window_index}')
WINDOW="${SESSION}:${WIN_IDX}"
# Append the initial agent record to the state file so subsequent jq updates find it.
# This must happen before the pr_number/steps update below.
if [ -f "$STATE_FILE" ]; then
NOW=$(date +%s)
jq --arg window "$WINDOW" \
--arg worktree "$WORKTREE_NAME" \
--arg worktree_path "$WORKTREE_PATH" \
--arg spare_branch "$SPARE_BRANCH" \
--arg branch "$NEW_BRANCH" \
--arg objective "$OBJECTIVE" \
--arg session_id "$SESSION_ID" \
--argjson now "$NOW" \
'.agents += [{
"window": $window,
"worktree": $worktree,
"worktree_path": $worktree_path,
"spare_branch": $spare_branch,
"branch": $branch,
"objective": $objective,
"session_id": $session_id,
"state": "running",
"checkpoints": [],
"last_output_hash": "",
"last_seen_at": $now,
"spawned_at": $now,
"idle_since": 0,
"revision_count": 0,
"last_rebriefed_at": 0
}]' "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
fi
# Store pr_number + steps in state file if provided (enables verify-complete.sh).
# The agent record was appended above so the jq select now finds it.
if [ -n "$PR_NUMBER" ] && [ -f "$STATE_FILE" ]; then
if [ "${#STEPS[@]}" -gt 0 ]; then
STEPS_JSON=$(printf '%s\n' "${STEPS[@]}" | jq -R . | jq -s .)
else
STEPS_JSON='[]'
fi
jq --arg w "$WINDOW" --arg pr "$PR_NUMBER" --argjson steps "$STEPS_JSON" \
'.agents |= map(if .window == $w then . + {pr_number: $pr, steps: $steps, checkpoints: []} else . end)' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
fi
# Launch claude with a stable session ID so it can always be resumed after a crash:
# claude --resume SESSION_ID --permission-mode bypassPermissions
tmux send-keys -t "$WINDOW" "cd '${WORKTREE_PATH}' && claude --permission-mode bypassPermissions --session-id '${SESSION_ID}'" Enter
# wait_for_claude_idle — poll until the pane shows idle with no spinner in the last 3 lines.
# Returns 0 when idle, 1 on timeout.
_wait_idle() {
local window="$1" timeout="${2:-60}" elapsed=0
while (( elapsed < timeout )); do
local cmd pane pane_tail
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
pane_tail=$(echo "$pane" | tail -3)
# Check full pane (not just tail) — 'Enter to confirm' dialog can appear above the last 3 lines
if echo "$pane" | grep -q "Enter to confirm"; then
tmux send-keys -t "$window" Down Enter
sleep 2; (( elapsed += 2 )); continue
fi
if [[ "$cmd" == "node" ]] && \
echo "$pane_tail" | grep -q "" && \
! echo "$pane_tail" | grep -qE '[✳✽✢✶·✻✼✿❋✤]|Running…|Compacting'; then
return 0
fi
sleep 2; (( elapsed += 2 ))
done
return 1
}
# Wait up to 60s for claude to be fully interactive and idle (❯ visible, no spinner).
if ! _wait_idle "$WINDOW" 60; then
echo "[spawn-agent] WARNING: timed out waiting for idle prompt on $WINDOW — sending objective anyway" >&2
fi
# Send the task. Split text and Enter — if combined, Enter can fire before the string
# is fully buffered, leaving the message stuck as "[Pasted text +N lines]" unsent.
tmux send-keys -t "$WINDOW" "${OBJECTIVE} Output each completed step as CHECKPOINT:<step-name>. When ALL steps are done, output ORCHESTRATOR:DONE on its own line."
sleep 0.3
tmux send-keys -t "$WINDOW" Enter
# Only output the window address — nothing else (callers parse this)
echo "$WINDOW"

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# status.sh — print orchestrator status: state file summary + live tmux pane commands
#
# Usage: status.sh
# Reads: ~/.claude/orchestrator-state.json
set -euo pipefail
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
if [ ! -f "$STATE_FILE" ] || ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
echo "No orchestrator state found at $STATE_FILE"
exit 0
fi
# Header: active status, session, thresholds, last poll
jq -r '
"=== Orchestrator [\(if .active then "RUNNING" else "STOPPED" end)] ===",
"Session: \(.tmux_session // "unknown") | Idle threshold: \(.idle_threshold_seconds // 300)s",
"Last poll: \(if (.last_poll_at // 0) == 0 then "never" else (.last_poll_at | strftime("%H:%M:%S")) end)",
""
' "$STATE_FILE"
# Each agent: state, window, worktree/branch, truncated objective
AGENT_COUNT=$(jq '.agents | length' "$STATE_FILE")
if [ "$AGENT_COUNT" -eq 0 ]; then
echo " (no agents registered)"
else
jq -r '
.agents[] |
" [\(.state | ascii_upcase)] \(.window) \(.worktree)/\(.branch)",
" \(.objective // "" | .[0:70])"
' "$STATE_FILE"
fi
echo ""
# Live pane_current_command for non-done agents
while IFS= read -r WINDOW; do
[ -z "$WINDOW" ] && continue
CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "unreachable")
echo " $WINDOW live: $CMD"
done < <(jq -r '.agents[] | select(.state != "done") | .window' "$STATE_FILE" 2>/dev/null || true)

View File

@@ -0,0 +1,180 @@
#!/usr/bin/env bash
# verify-complete.sh — verify a PR task is truly done before marking the agent done
#
# Check order matters:
# 1. Checkpoints — did the agent do all required steps?
# 2. CI complete — no pending (bots post comments AFTER their check runs, must wait)
# 3. CI passing — no failures (agent must fix before done)
# 4. spawned_at — a new CI run was triggered after agent spawned (proves real work)
# 5. Unresolved threads — checked AFTER CI so bot-posted comments are included
# 6. CHANGES_REQUESTED — checked AFTER CI so bot reviews are included
#
# Usage: verify-complete.sh WINDOW
# Exit 0 = verified complete; exit 1 = not complete (stderr has reason)
set -euo pipefail
WINDOW="$1"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
PR_NUMBER=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .pr_number // ""' "$STATE_FILE" 2>/dev/null)
STEPS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .steps // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
CHECKPOINTS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .checkpoints // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
WORKTREE_PATH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .worktree_path // ""' "$STATE_FILE" 2>/dev/null)
BRANCH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .branch // ""' "$STATE_FILE" 2>/dev/null)
SPAWNED_AT=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .spawned_at // "0"' "$STATE_FILE" 2>/dev/null || echo "0")
# No PR number = cannot verify
if [ -z "$PR_NUMBER" ]; then
echo "NOT COMPLETE: no pr_number in state — set pr_number or mark done manually" >&2
exit 1
fi
# --- Check 1: all required steps are checkpointed ---
MISSING=""
while IFS= read -r step; do
[ -z "$step" ] && continue
if ! echo "$CHECKPOINTS" | grep -qFx "$step"; then
MISSING="$MISSING $step"
fi
done <<< "$STEPS"
if [ -n "$MISSING" ]; then
echo "NOT COMPLETE: missing checkpoints:$MISSING on PR #$PR_NUMBER" >&2
exit 1
fi
# Resolve repo for all GitHub checks below
REPO=$(jq -r '.repo // ""' "$STATE_FILE" 2>/dev/null || echo "")
if [ -z "$REPO" ] && [ -n "$WORKTREE_PATH" ] && [ -d "$WORKTREE_PATH" ]; then
REPO=$(git -C "$WORKTREE_PATH" remote get-url origin 2>/dev/null \
| sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
fi
if [ -z "$REPO" ]; then
echo "Warning: cannot resolve repo — skipping CI/thread checks" >&2
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓ (CI/thread checks skipped — no repo)"
exit 0
fi
CI_BUCKETS=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket 2>/dev/null || echo "[]")
# --- Check 2: CI fully complete — no pending checks ---
# Pending checks MUST finish before we check threads/reviews:
# bots (Seer, Check PR Status, etc.) post comments and CHANGES_REQUESTED AFTER their CI check runs.
PENDING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "pending")] | length' 2>/dev/null || echo "0")
if [ "$PENDING" -gt 0 ]; then
PENDING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
| jq -r '[.[] | select(.bucket == "pending") | .name] | join(", ")' 2>/dev/null || echo "unknown")
echo "NOT COMPLETE: $PENDING CI checks still pending on PR #$PR_NUMBER ($PENDING_NAMES)" >&2
exit 1
fi
# --- Check 3: CI passing — no failures ---
FAILING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "fail")] | length' 2>/dev/null || echo "0")
if [ "$FAILING" -gt 0 ]; then
FAILING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
| jq -r '[.[] | select(.bucket == "fail") | .name] | join(", ")' 2>/dev/null || echo "unknown")
echo "NOT COMPLETE: $FAILING failing CI checks on PR #$PR_NUMBER ($FAILING_NAMES)" >&2
exit 1
fi
# --- Check 4: a new CI run was triggered AFTER the agent spawned ---
if [ -n "$BRANCH" ] && [ "${SPAWNED_AT:-0}" -gt 0 ]; then
LATEST_RUN_AT=$(gh run list --repo "$REPO" --branch "$BRANCH" \
--json createdAt --limit 1 2>/dev/null | jq -r '.[0].createdAt // ""')
if [ -n "$LATEST_RUN_AT" ]; then
if date --version >/dev/null 2>&1; then
LATEST_RUN_EPOCH=$(date -d "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
else
LATEST_RUN_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
fi
if [ "$LATEST_RUN_EPOCH" -le "$SPAWNED_AT" ]; then
echo "NOT COMPLETE: latest CI run on $BRANCH predates agent spawn — agent may not have pushed yet" >&2
exit 1
fi
fi
fi
OWNER=$(echo "$REPO" | cut -d/ -f1)
REPONAME=$(echo "$REPO" | cut -d/ -f2)
# --- Check 5: no unresolved review threads (checked AFTER CI — bots post after their check) ---
UNRESOLVED=$(gh api graphql -f query="
{ repository(owner: \"${OWNER}\", name: \"${REPONAME}\") {
pullRequest(number: ${PR_NUMBER}) {
reviewThreads(first: 50) { nodes { isResolved } }
}
}
}
" --jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)] | length' 2>/dev/null || echo "0")
if [ "$UNRESOLVED" -gt 0 ]; then
echo "NOT COMPLETE: $UNRESOLVED unresolved review threads on PR #$PR_NUMBER" >&2
exit 1
fi
# --- Check 6: no CHANGES_REQUESTED (checked AFTER CI — bots post reviews after their check) ---
# A CHANGES_REQUESTED review is stale if the latest commit was pushed AFTER the review was submitted.
# Stale reviews (pre-dating the fixing commits) should not block verification.
#
# Fetch commits and latestReviews in a single call and fail closed — if gh fails,
# treat that as NOT COMPLETE rather than silently passing.
# Use latestReviews (not reviews) so each reviewer's latest state is used — superseded
# CHANGES_REQUESTED entries are automatically excluded when the reviewer later approved.
# Note: we intentionally use committedDate (not PR updatedAt) because updatedAt changes on any
# PR activity (bot comments, label changes) which would create false negatives.
PR_REVIEW_METADATA=$(gh pr view "$PR_NUMBER" --repo "$REPO" \
--json commits,latestReviews 2>/dev/null) || {
echo "NOT COMPLETE: unable to fetch PR review metadata for PR #$PR_NUMBER" >&2
exit 1
}
LATEST_COMMIT_DATE=$(jq -r '.commits[-1].committedDate // ""' <<< "$PR_REVIEW_METADATA")
CHANGES_REQUESTED_REVIEWS=$(jq '[.latestReviews[]? | select(.state == "CHANGES_REQUESTED")]' <<< "$PR_REVIEW_METADATA")
BLOCKING_CHANGES_REQUESTED=0
BLOCKING_REQUESTERS=""
if [ -n "$LATEST_COMMIT_DATE" ] && [ "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length)" -gt 0 ]; then
if date --version >/dev/null 2>&1; then
LATEST_COMMIT_EPOCH=$(date -d "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
else
LATEST_COMMIT_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
fi
while IFS= read -r review; do
[ -z "$review" ] && continue
REVIEW_DATE=$(echo "$review" | jq -r '.submittedAt // ""')
REVIEWER=$(echo "$review" | jq -r '.author.login // "unknown"')
if [ -z "$REVIEW_DATE" ]; then
# No submission date — treat as fresh (conservative: blocks verification)
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
else
if date --version >/dev/null 2>&1; then
REVIEW_EPOCH=$(date -d "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
else
REVIEW_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
fi
if [ "$REVIEW_EPOCH" -gt "$LATEST_COMMIT_EPOCH" ]; then
# Review was submitted AFTER latest commit — still fresh, blocks verification
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
fi
# Review submitted BEFORE latest commit — stale, skip
fi
done <<< "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -c '.[]')"
else
# No commit date or no changes_requested — check raw count as fallback
BLOCKING_CHANGES_REQUESTED=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length 2>/dev/null || echo "0")
BLOCKING_REQUESTERS=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -r '[.[].author.login] | join(", ")' 2>/dev/null || echo "unknown")
fi
if [ "$BLOCKING_CHANGES_REQUESTED" -gt 0 ]; then
echo "NOT COMPLETE: CHANGES_REQUESTED (after latest commit) from ${BLOCKING_REQUESTERS} on PR #$PR_NUMBER" >&2
exit 1
fi
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓, CI complete + green, 0 unresolved threads, no CHANGES_REQUESTED"
exit 0

View File

@@ -29,30 +29,71 @@ gh pr view {N} --json body --jq '.body'
### 1. Inline review threads — GraphQL (primary source of actionable items)
Use GraphQL to fetch inline threads. It natively exposes `isResolved`, returns threads already grouped with all replies, and paginates via cursor — no manual thread reconstruction needed.
> ⚠️ **WARNING — PAGINATE ALL PAGES BEFORE ADDRESSING ANYTHING**
>
> `reviewThreads(first: 100)` returns at most 100 threads per page. A PR with many review cycles can have 140+ threads across 2+ pages. **If you start addressing threads after fetching only page 1, you will miss all threads on subsequent pages and silently leave them unresolved.**
>
> PR #12636 had 142 total threads: page 1 returned 69 unresolved, page 2 had 42 more (111 total unresolved). An agent that stopped after page 1 addressed only 69 and falsely reported "done".
>
> **The rule: collect ALL thread IDs from ALL pages into a single list, then address them.**
**Step 1 — Fetch total count first:**
```bash
gh api graphql -f query='
{
  repository(owner: "Significant-Gravitas", name: "AutoGPT") {
    pullRequest(number: {N}) {
      reviewThreads { totalCount }
    }
  }
}' | jq '.data.repository.pullRequest.reviewThreads.totalCount'
```
If `totalCount > 100`, you have multiple pages. Fetch them all before doing anything else.
**Step 2 — Collect all unresolved thread IDs across all pages:**
```bash
# Accumulate all unresolved threads — loop until hasNextPage == false
CURSOR=""
ALL_THREADS="[]"
while true; do
AFTER=${CURSOR:+", after: \"$CURSOR\""}
PAGE=$(gh api graphql -f query="
{
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
pullRequest(number: {N}) {
reviewThreads(first: 100${AFTER}) {
pageInfo { hasNextPage endCursor }
nodes {
id
isResolved
path
line
comments(last: 1) {
nodes { databaseId body author { login } }
            }
          }
        }
      }
    }
  }")
# Append unresolved nodes from this page
PAGE_THREADS=$(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)]')
ALL_THREADS=$(echo "$ALL_THREADS $PAGE_THREADS" | jq -s 'add')
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
[ "$HAS_NEXT" = "false" ] && break
done
echo "Total unresolved threads: $(echo "$ALL_THREADS" | jq 'length')"
echo "$ALL_THREADS" | jq '[.[] | {id, path, line, body: .comments.nodes[0].body[:200]}]'
```
**Step 3 — Address every thread in `ALL_THREADS`, then resolve.**
Only after this loop completes (all pages fetched, count confirmed) should you begin making fixes.
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
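A minimal sketch of the Step 3 iteration, assuming `ALL_THREADS` was accumulated by the Step 2 loop above (it already contains only unresolved threads); the fix/commit/reply/resolve body is a placeholder, detailed in the next section:
```bash
# Sketch: walk every collected thread — ALL_THREADS comes from the Step 2 loop.
echo "$ALL_THREADS" | jq -c '.[]' | while IFS= read -r thread; do
  THREAD_ID=$(echo "$thread" | jq -r '.id')
  FILE_PATH=$(echo "$thread" | jq -r '.path')
  COMMENT_ID=$(echo "$thread" | jq -r '.comments.nodes[0].databaseId')
  echo "Addressing thread $THREAD_ID in $FILE_PATH (comment $COMMENT_ID)"
  # fix → commit → push → inline reply → resolve (see below)
done
```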
@@ -84,16 +125,43 @@ Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`gi
## For each unaddressed comment
**CRITICAL: The only valid sequence is fix → commit → push → reply → resolve. Never resolve a thread without a real code commit.**
Resolving a thread via `resolveReviewThread` without an actual fix is the most common failure mode — it makes unresolved counts drop without any real change, producing a false "done" signal. If the issue was genuinely a false positive (no code change needed), reply explaining why and then resolve. Otherwise:
Address comments **one at a time**: fix → commit → push → inline reply → resolve.
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
2. Commit and push the fix
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
Use a **markdown commit link** so GitHub renders it as a clickable reference. Always get the full SHA with `git rev-parse HEAD` **after** committing — never copy a SHA from a previous commit or hardcode one:
```bash
FULL_SHA=$(git rev-parse HEAD)
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies \
-f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): <description>"
```
| Comment type | How to reply |
|---|---|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
### What counts as a valid resolution
Only two situations justify calling `resolveReviewThread`:
1. **Real code fix**: you changed the code, committed + pushed, and replied with the SHA. The commit diff must actually address the concern — not just touch the same file.
2. **Genuine false positive**: the reviewer's concern does not apply to this code, and you can give a specific technical reason (e.g. "Not applicable — `sdk_cwd` is pre-validated by `_make_sdk_cwd()` which applies normpath + prefix assertion before reaching this point").
**Anti-patterns that look resolved but aren't — never do these:**
- `"Accepted, tracked as follow-up"` — a deferral, not a fix. The concern is still open. Do not resolve.
- `"Acknowledged"` or `"Same as above"` — these are acknowledgements, not fixes. Do not resolve.
- `"Fixed in abc1234"` where `abc1234` is a commit that doesn't actually change the flagged line/logic — dishonest. Verify `git show abc1234 -- path/to/file` changes the right thing before posting.
- Resolving without replying — the reviewer never sees what happened.
When in doubt: if a code change is needed, make it. A deferred issue means the thread stays open until the follow-up PR is merged.
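A hedged guard against the dishonest-SHA anti-pattern — before posting a reply, confirm the commit actually changes the flagged path (`verify_fix_touches_path` is an illustrative helper, not part of the skill's scripts):
```bash
# Illustrative: only claim "Fixed in $SHA" when the commit touches the flagged file.
verify_fix_touches_path() {
  local sha="$1" path="$2"
  if git show --name-only --format='' "$sha" | grep -qFx "$path"; then
    return 0
  fi
  echo "Commit $sha does not change $path — fix first, do not resolve" >&2
  return 1
}
```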
## Codecov coverage
@@ -141,6 +209,22 @@ Then commit and **push immediately** — never batch commits without pushing. Ea
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
## Coverage
Codecov enforces patch coverage on new/changed lines — new code you write must be tested. Before pushing, verify you haven't left new lines uncovered:
```bash
cd autogpt_platform/backend
poetry run pytest --cov=. --cov-report=term-missing {path/to/changed/module}
```
Look for lines marked `miss` — those are uncovered. Add tests for any new code you wrote as part of addressing comments.
**Rules:**
- New code you add should have tests
- Don't remove existing tests when fixing comments
- If a reviewer asks you to delete code, also delete its tests, but verify coverage hasn't dropped on remaining lines
## The loop
```text
@@ -230,3 +314,113 @@ git push
```
5. Restart the polling loop from the top — new commits reset CI status.
## GitHub abuse rate limits
Two distinct rate limits exist — they have different causes and recovery times:
| Error | HTTP code | Cause | Recovery |
|---|---|---|---|
| `{"code":"abuse"}` | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **23 minutes**. 60s is often not enough. |
| `{"message":"API rate limit exceeded"}` | 429 | Primary rate limit — too many API calls per hour | Wait until `X-RateLimit-Reset` header timestamp |
**Prevention:** Add `sleep 3` between individual thread reply API calls. When posting >20 replies, increase to `sleep 5`.
**Recovery from secondary rate limit (403):**
1. Stop all API writes immediately
2. Wait **2 minutes minimum** (not 60s — secondary limits are stricter)
3. Resume with `sleep 3` between each call
4. If 403 persists after 2 min, wait another 2 min before retrying
Never batch all replies in a tight loop — always space them out.
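One way to enforce the spacing mechanically — a hedged sketch of a wrapper around write calls that sleeps between successes and backs off once on a secondary-limit 403 (the 120s figure follows the recovery steps above):
```bash
# Sketch: spaced gh api write with one backoff retry on the secondary rate limit.
# The response body is discarded — fine for fire-and-forget replies.
gh_write() {
  local out
  if out=$(gh api "$@" 2>&1); then
    sleep 3   # space out successive writes
    return 0
  fi
  if echo "$out" | grep -q '"abuse"'; then
    echo "Secondary rate limit — waiting 120s before one retry" >&2
    sleep 120
    gh api "$@" && sleep 3
  else
    echo "$out" >&2
    return 1
  fi
}
```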
## Parallel thread resolution
When a PR has more than 10 unresolved threads, addressing one commit per thread is slow. Use this strategy instead:
### Group by file, batch per commit
1. Sort `ALL_THREADS` by `path` — threads in the same file can share a single commit.
2. Fix all threads in one file → `git commit` → `git push` → reply to **all** those threads with the same SHA → resolve them all.
3. Move to the next file group and repeat.
This reduces N commits to (number of files touched), which is usually 3–5 instead of 15–30.
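A sketch of the grouping itself, assuming `ALL_THREADS` from the pagination loop — one line per file, carrying every thread ID that the file's single commit should answer:
```bash
# Group unresolved threads by file path: "path<TAB>id1 id2 ..." per line.
echo "$ALL_THREADS" | jq -r '
  group_by(.path)[] | "\(.[0].path)\t\([.[].id] | join(" "))"
' | while IFS=$'\t' read -r FILE_PATH THREAD_IDS; do
  echo "One commit for $FILE_PATH, then reply/resolve: $THREAD_IDS"
done
```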
### Posting replies concurrently (for large batches)
For truly independent thread groups (different files, no shared logic), you can post replies in parallel using background subshells — but always space out API writes:
```bash
# Post replies to a batch of threads concurrently, 3s apart
(
sleep 3
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID1}/replies \
-f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): ..."
) &
(
sleep 6
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID2}/replies \
-f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): ..."
) &
wait # wait for all background replies before resolving
```
Then resolve sequentially (GraphQL mutations):
```bash
for THREAD_ID in "$THREAD1" "$THREAD2" "$THREAD3"; do
gh api graphql -f query="mutation { resolveReviewThread(input: {threadId: \"${THREAD_ID}\"}) { thread { isResolved } } }"
sleep 3
done
```
**Always sleep 3s between individual API writes** — GitHub's secondary rate limit (403) triggers on bursts of >20 writes. Increase to `sleep 5` when posting more than 20 replies in a batch.
## Resolving threads via GraphQL
Use `resolveReviewThread` **only after** the commit is pushed and the reply is posted:
```bash
gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "THREAD_ID"}) { thread { isResolved } } }'
```
**Never call this mutation before committing the fix.** The orchestrator will verify actual unresolved counts via GraphQL after you output `ORCHESTRATOR:DONE` — false resolutions will be caught and you will be re-briefed.
### Verify actual count before outputting ORCHESTRATOR:DONE
Before claiming "0 unresolved threads", always query GitHub directly — don't rely on your own bookkeeping. Paginate all pages — a single `first: 100` query misses threads beyond page 1:
```bash
# Step 1: get total thread count
gh api graphql -f query='
{
repository(owner: "Significant-Gravitas", name: "AutoGPT") {
pullRequest(number: {N}) {
reviewThreads { totalCount }
}
}
}' | jq '.data.repository.pullRequest.reviewThreads.totalCount'
# Step 2: paginate all pages, count truly unresolved
CURSOR=""; UNRESOLVED=0
while true; do
AFTER=${CURSOR:+", after: \"$CURSOR\""}
PAGE=$(gh api graphql -f query="
{
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
pullRequest(number: {N}) {
reviewThreads(first: 100${AFTER}) {
pageInfo { hasNextPage endCursor }
nodes { isResolved }
}
}
}
}")
UNRESOLVED=$(( UNRESOLVED + $(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved==false)] | length') ))
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
[ "$HAS_NEXT" = "false" ] && break
done
echo "Unresolved threads: $UNRESOLVED"
```
Only output `ORCHESTRATOR:DONE` after this loop reports 0.

View File

@@ -310,6 +310,28 @@ TOKEN=$(curl -s -X POST 'http://localhost:8000/auth/v1/token?grant_type=password
curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
```
### 3i. Disable onboarding for test user
The frontend redirects to `/onboarding` when the `VISIT_COPILOT` step is not in `completedSteps`.
Mark it complete via the backend API so every browser test lands on the real feature UI:
```bash
ONBOARDING_RESULT=$(curl -s --max-time 30 -X POST \
"http://localhost:8006/api/onboarding/step?step=VISIT_COPILOT" \
-H "Authorization: Bearer $TOKEN")
echo "Onboarding bypass: $ONBOARDING_RESULT"
# Verify it took effect
ONBOARDING_STATUS=$(curl -s --max-time 30 \
"http://localhost:8006/api/onboarding/completed" \
-H "Authorization: Bearer $TOKEN" | jq -r '.is_completed')
echo "Onboarding completed: $ONBOARDING_STATUS"
if [ "$ONBOARDING_STATUS" != "true" ]; then
echo "ERROR: onboarding bypass failed — browser tests will hit /onboarding instead of the target feature. Investigate before proceeding."
exit 1
fi
```
## Step 4: Run tests
### Service ports reference
@@ -547,6 +569,8 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"
@@ -584,15 +608,27 @@ TREE_JSON+=']'
# Step 2: Create tree, commit, and branch ref
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
--jq '.sha')
# Resolve parent commit so screenshots are chained, not orphan root commits
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || echo "")
if [ -n "$PARENT_SHA" ]; then
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
-f "parents[]=$PARENT_SHA" \
--jq '.sha')
else
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
--jq '.sha')
fi
gh api "repos/${REPO}/git/refs" \
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
-f sha="$COMMIT_SHA" 2>/dev/null \
|| gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
-X PATCH -f sha="$COMMIT_SHA" -f force=true
-X PATCH -f sha="$COMMIT_SHA" -F force=true
```
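Optionally, a hedged sanity check after the ref call — confirm the branch now points at the new commit so the raw URLs in the comment won't 404:
```bash
# Verify the screenshots ref now points at the commit we just created.
CURRENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha')
if [ "$CURRENT_SHA" != "$COMMIT_SHA" ]; then
  echo "ERROR: screenshots ref update failed — raw URLs would 404" >&2
  exit 1
fi
```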
Then post the comment with **inline images AND explanations for each screenshot**:
@@ -658,6 +694,15 @@ INNEREOF
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
rm -f "$COMMENT_FILE"
# Verify the posted comment contains inline images — exit 1 if none found
# Use a separate jq pipe (not gh --jq, which applies per page); slurp all pages and
# concatenate them so .[-1] is truly the last comment across pages.
LAST_COMMENT=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" --paginate 2>/dev/null | jq -rs 'add | .[-1].body // ""')
if ! echo "$LAST_COMMENT" | grep -q '!\['; then
echo "ERROR: Posted comment contains no inline images (![). Bare directory links are not acceptable." >&2
exit 1
fi
echo "✓ Inline images verified in posted comment"
```
**The PR comment MUST include:**
@@ -667,6 +712,103 @@ rm -f "$COMMENT_FILE"
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
## Step 8: Evaluate and post a formal PR review
After the test comment is posted, evaluate whether the run was thorough enough to make a merge decision, then post a formal GitHub review (approve or request changes). **This step is mandatory — every test run MUST end with a formal review decision.**
### Evaluation criteria
Re-read the PR description:
```bash
gh pr view "$PR_NUMBER" --json body --jq '.body' --repo "$REPO"
```
Score the run against each criterion:
| Criterion | Pass condition |
|-----------|---------------|
| **Coverage** | Every feature/change described in the PR has at least one test scenario |
| **All scenarios pass** | No FAIL rows in the results table |
| **Negative tests** | At least one failure-path test per feature (invalid input, unauthorized, edge case) |
| **Before/after evidence** | Every state-changing API call has before/after values logged |
| **Screenshots are meaningful** | Screenshots show the actual state change, not just a loading spinner or blank page |
| **No regressions** | Existing core flows (login, agent create/run) still work |
### Decision logic
```
ALL criteria pass → APPROVE
Any scenario FAIL or missing PR feature → REQUEST_CHANGES (list gaps)
Evidence weak (no before/after, vague shots) → REQUEST_CHANGES (list what's missing)
```
### Post the review
```bash
REVIEW_FILE=$(mktemp)
# Count results
PASS_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "PASS" || true)
FAIL_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "FAIL" || true)
TOTAL=$(( PASS_COUNT + FAIL_COUNT ))
# List any coverage gaps found during evaluation (populate this array as you assess)
# e.g. COVERAGE_GAPS=("PR claims to add X but no test covers it")
COVERAGE_GAPS=()
```
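A minimal sketch wiring the decision logic to these counters — it only picks which of the two review paths below to run (evidence quality still needs human-level judgment):
```bash
# Sketch: choose the review path from the counters above.
if [ "$FAIL_COUNT" -eq 0 ] && [ "${#COVERAGE_GAPS[@]}" -eq 0 ]; then
  echo "All criteria pass → APPROVE"
else
  echo "Failures or coverage gaps → REQUEST_CHANGES"
fi
```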
**If APPROVING** — all criteria met, zero failures, full coverage:
```bash
cat > "$REVIEW_FILE" <<REVIEWEOF
## E2E Test Evaluation — APPROVED
**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed.
**Coverage:** All features described in the PR were exercised.
**Evidence:** Before/after API values logged for all state-changing operations; screenshots show meaningful state transitions.
**Negative tests:** Failure paths tested for each feature.
No regressions observed on core flows.
REVIEWEOF
gh pr review "$PR_NUMBER" --repo "$REPO" --approve --body "$(cat "$REVIEW_FILE")"
echo "✅ PR approved"
```
**If REQUESTING CHANGES** — any failure, coverage gap, or missing evidence:
```bash
FAIL_LIST=$(echo "$TEST_RESULTS_TABLE" | grep "FAIL" | awk -F'|' '{print "- Scenario" $2 "failed"}' || true)
cat > "$REVIEW_FILE" <<REVIEWEOF
## E2E Test Evaluation — Changes Requested
**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed, ${FAIL_COUNT} failed.
### Required before merge
${FAIL_LIST}
$(for gap in "${COVERAGE_GAPS[@]}"; do echo "- $gap"; done)
Please fix the above and re-run the E2E tests.
REVIEWEOF
gh pr review "$PR_NUMBER" --repo "$REPO" --request-changes --body "$(cat "$REVIEW_FILE")"
echo "❌ Changes requested"
```
```bash
rm -f "$REVIEW_FILE"
```
**Rules:**
- In `--fix` mode, fix all failures before posting the review — the review reflects the final state after fixes
- Never approve if any scenario failed, even if it seems like a flake — rerun that scenario first
- Never request changes for issues already fixed in this run
## Fix mode (--fix flag)
When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.

View File

@@ -0,0 +1,84 @@
import logging
from datetime import datetime
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, Query, Security
from pydantic import BaseModel
from backend.data.platform_cost import (
CostLogRow,
PlatformCostDashboard,
get_platform_cost_dashboard,
get_platform_cost_logs,
)
from backend.util.models import Pagination
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/platform-costs",
tags=["platform-cost", "admin"],
dependencies=[Security(requires_admin_user)],
)
class PlatformCostLogsResponse(BaseModel):
logs: list[CostLogRow]
pagination: Pagination
@router.get(
"/dashboard",
response_model=PlatformCostDashboard,
summary="Get Platform Cost Dashboard",
)
async def get_cost_dashboard(
admin_user_id: str = Security(get_user_id),
start: datetime | None = Query(None),
end: datetime | None = Query(None),
provider: str | None = Query(None),
user_id: str | None = Query(None),
):
logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
return await get_platform_cost_dashboard(
start=start,
end=end,
provider=provider,
user_id=user_id,
)
@router.get(
"/logs",
response_model=PlatformCostLogsResponse,
summary="Get Platform Cost Logs",
)
async def get_cost_logs(
admin_user_id: str = Security(get_user_id),
start: datetime | None = Query(None),
end: datetime | None = Query(None),
provider: str | None = Query(None),
user_id: str | None = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=200),
):
logger.info("Admin %s fetching platform cost logs", admin_user_id)
logs, total = await get_platform_cost_logs(
start=start,
end=end,
provider=provider,
user_id=user_id,
page=page,
page_size=page_size,
)
total_pages = (total + page_size - 1) // page_size
return PlatformCostLogsResponse(
logs=logs,
pagination=Pagination(
total_items=total,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)

View File

@@ -0,0 +1,192 @@
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.platform_cost import PlatformCostDashboard
from .platform_cost_routes import router as platform_cost_router
app = fastapi.FastAPI()
app.include_router(platform_cost_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_get_dashboard_success(
mocker: pytest_mock.MockerFixture,
) -> None:
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=0,
total_requests=0,
total_users=0,
)
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
AsyncMock(return_value=real_dashboard),
)
response = client.get("/platform-costs/dashboard")
assert response.status_code == 200
data = response.json()
assert "by_provider" in data
assert "by_user" in data
assert data["total_cost_microdollars"] == 0
def test_get_logs_success(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
AsyncMock(return_value=([], 0)),
)
response = client.get("/platform-costs/logs")
assert response.status_code == 200
data = response.json()
assert data["logs"] == []
assert data["pagination"]["total_items"] == 0
def test_get_dashboard_with_filters(
mocker: pytest_mock.MockerFixture,
) -> None:
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=0,
total_requests=0,
total_users=0,
)
mock_dashboard = AsyncMock(return_value=real_dashboard)
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
mock_dashboard,
)
response = client.get(
"/platform-costs/dashboard",
params={
"start": "2026-01-01T00:00:00",
"end": "2026-04-01T00:00:00",
"provider": "openai",
"user_id": "test-user-123",
},
)
assert response.status_code == 200
mock_dashboard.assert_called_once()
call_kwargs = mock_dashboard.call_args.kwargs
assert call_kwargs["provider"] == "openai"
assert call_kwargs["user_id"] == "test-user-123"
assert call_kwargs["start"] is not None
assert call_kwargs["end"] is not None
def test_get_logs_with_pagination(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
AsyncMock(return_value=([], 0)),
)
response = client.get(
"/platform-costs/logs",
params={"page": 2, "page_size": 25, "provider": "anthropic"},
)
assert response.status_code == 200
data = response.json()
assert data["pagination"]["current_page"] == 2
assert data["pagination"]["page_size"] == 25
def test_get_dashboard_requires_admin() -> None:
from fastapi import HTTPException
def reject_jwt(request: fastapi.Request):
raise HTTPException(status_code=401, detail="Not authenticated")
app.dependency_overrides[get_jwt_payload] = reject_jwt
try:
response = client.get("/platform-costs/dashboard")
assert response.status_code == 401
response = client.get("/platform-costs/logs")
assert response.status_code == 401
finally:
app.dependency_overrides.clear()
def test_get_dashboard_rejects_non_admin(mock_jwt_user, mock_jwt_admin) -> None:
"""Non-admin JWT must be rejected with 403 by requires_admin_user."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
try:
response = client.get("/platform-costs/dashboard")
assert response.status_code == 403
response = client.get("/platform-costs/logs")
assert response.status_code == 403
finally:
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
def test_get_logs_invalid_page_size_too_large() -> None:
"""page_size > 200 must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page_size": 201})
assert response.status_code == 422
def test_get_logs_invalid_page_size_zero() -> None:
"""page_size = 0 (below ge=1) must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page_size": 0})
assert response.status_code == 422
def test_get_logs_invalid_page_negative() -> None:
"""page < 1 must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page": 0})
assert response.status_code == 422
def test_get_dashboard_invalid_date_format() -> None:
"""Malformed start date must be rejected with 422."""
response = client.get("/platform-costs/dashboard", params={"start": "not-a-date"})
assert response.status_code == 422
def test_get_dashboard_repeated_requests(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Repeated requests to the dashboard route both return 200."""
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=42,
total_requests=1,
total_users=1,
)
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
AsyncMock(return_value=real_dashboard),
)
r1 = client.get("/platform-costs/dashboard")
r2 = client.get("/platform-costs/dashboard")
assert r1.status_code == 200
assert r2.status_code == 200
assert r1.json()["total_cost_microdollars"] == 42
assert r2.json()["total_cost_microdollars"] == 42

View File

@@ -16,6 +16,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
@@ -155,6 +156,8 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
has_more_messages: bool = False
oldest_sequence: int | None = None
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
metadata: ChatSessionMetadata = ChatSessionMetadata()
@@ -394,60 +397,78 @@ async def update_session_title_route(
async def get_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
limit: int = Query(default=50, ge=1, le=200),
before_sequence: int | None = Query(default=None, ge=0),
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns active_stream info for reconnection.
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
When no pagination params are provided, returns the most recent messages.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The authenticated user's ID.
limit: Maximum number of messages to return (1-200, default 50).
before_sequence: Return messages with sequence < this value (cursor).
Returns:
SessionDetailResponse: Details for the requested session, including
active_stream info and pagination metadata.
"""
page = await get_chat_messages_paginated(
session_id, limit, before_sequence, user_id=user_id
)
if page is None:
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in page.messages]
# Only check active stream on initial load (not on "load more" requests)
active_stream_info = None
if before_sequence is None:
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
# Skip session metadata on "load more" — frontend only needs messages
if before_sequence is not None:
return SessionDetailResponse(
id=page.session.session_id,
created_at=page.session.started_at.isoformat(),
updated_at=page.session.updated_at.isoformat(),
user_id=page.session.user_id or None,
messages=messages,
active_stream=None,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
total_prompt_tokens=0,
total_completion_tokens=0,
)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in page.session.usage)
total_completion = sum(u.completion_tokens for u in page.session.usage)
return SessionDetailResponse(
id=page.session.session_id,
created_at=page.session.started_at.isoformat(),
updated_at=page.session.updated_at.isoformat(),
user_id=page.session.user_id or None,
messages=messages,
active_stream=active_stream_info,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
metadata=page.session.metadata,
)

View File

@@ -12,7 +12,7 @@ import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi import Query, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel, Field
from backend.data.workspace import (
WorkspaceFile,
@@ -131,9 +131,26 @@ class StorageUsageResponse(BaseModel):
file_count: int
class WorkspaceFileItem(BaseModel):
id: str
name: str
path: str
mime_type: str
size_bytes: int
metadata: dict = Field(default_factory=dict)
created_at: str
class ListFilesResponse(BaseModel):
files: list[WorkspaceFileItem]
offset: int = 0
has_more: bool = False
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
operation_id="getWorkspaceDownloadFileById",
)
async def download_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -158,6 +175,7 @@ async def download_file(
@router.delete(
"/files/{file_id}",
summary="Delete a workspace file",
operation_id="deleteWorkspaceFile",
)
async def delete_workspace_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -183,6 +201,7 @@ async def delete_workspace_file(
@router.post(
"/files/upload",
summary="Upload file to workspace",
operation_id="uploadWorkspaceFile",
)
async def upload_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -196,6 +215,9 @@ async def upload_file(
Files are stored in session-scoped paths when session_id is provided,
so the agent's session-scoped tools can discover them automatically.
"""
# Empty-string session_id drops session scoping; normalize to None.
session_id = session_id or None
config = Config()
# Sanitize filename — strip any directory components
@@ -250,16 +272,27 @@ async def upload_file(
manager = WorkspaceManager(user_id, workspace.id, session_id)
try:
workspace_file = await manager.write_file(
content, filename, overwrite=overwrite, metadata={"origin": "user-upload"}
)
except ValueError as e:
# write_file raises ValueError for both path-conflict and size-limit
# cases; map each to its correct HTTP status.
message = str(e)
if message.startswith("File too large"):
raise fastapi.HTTPException(status_code=413, detail=message) from e
raise fastapi.HTTPException(status_code=409, detail=message) from e
# Post-write storage check — eliminates TOCTOU race on the quota.
# If a concurrent upload pushed us over the limit, undo this write.
new_total = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and new_total > storage_limit_bytes:
try:
await soft_delete_workspace_file(workspace_file.id, workspace.id)
except Exception as e:
logger.warning(
f"Failed to soft-delete over-quota file {workspace_file.id} "
f"in workspace {workspace.id}: {e}"
)
raise fastapi.HTTPException(
status_code=413,
detail={
@@ -281,6 +314,7 @@ async def upload_file(
@router.get(
"/storage/usage",
summary="Get workspace storage usage",
operation_id="getWorkspaceStorageUsage",
)
async def get_storage_usage(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -301,3 +335,57 @@ async def get_storage_usage(
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
file_count=file_count,
)
@router.get(
"/files",
summary="List workspace files",
operation_id="listWorkspaceFiles",
)
async def list_workspace_files(
user_id: Annotated[str, fastapi.Security(get_user_id)],
session_id: str | None = Query(default=None),
limit: int = Query(default=200, ge=1, le=1000),
offset: int = Query(default=0, ge=0),
) -> ListFilesResponse:
"""
List files in the user's workspace.
When session_id is provided, only files for that session are returned.
Otherwise, all files across sessions are listed. Results are paginated
via `limit`/`offset`; `has_more` indicates whether additional pages exist.
"""
workspace = await get_or_create_workspace(user_id)
# Treat empty-string session_id the same as omitted — an empty value
# would otherwise silently list files across every session instead of
# scoping to one.
session_id = session_id or None
manager = WorkspaceManager(user_id, workspace.id, session_id)
include_all = session_id is None
# Fetch one extra to compute has_more without a separate count query.
files = await manager.list_files(
limit=limit + 1,
offset=offset,
include_all_sessions=include_all,
)
has_more = len(files) > limit
page = files[:limit]
return ListFilesResponse(
files=[
WorkspaceFileItem(
id=f.id,
name=f.name,
path=f.path,
mime_type=f.mime_type,
size_bytes=f.size_bytes,
metadata=f.metadata or {},
created_at=f.created_at.isoformat(),
)
for f in page
],
offset=offset,
has_more=has_more,
)

View File

@@ -1,48 +1,28 @@
"""Tests for workspace file upload and download routes."""
import io
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.workspace.routes import router
from backend.data.workspace import Workspace, WorkspaceFile
app = fastapi.FastAPI()
app.include_router(router)
@app.exception_handler(ValueError)
async def _value_error_handler(
request: fastapi.Request, exc: ValueError
) -> fastapi.responses.JSONResponse:
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
"""Mirror the production ValueError → 400 mapping from the REST app."""
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-1",
created_at=_NOW,
updated_at=_NOW,
name="hello.txt",
path="/session/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
@@ -53,25 +33,201 @@ def setup_app_auth(mock_jwt_user):
app.dependency_overrides.clear()
def _make_workspace(user_id: str = "test-user-id") -> Workspace:
return Workspace(
id="ws-001",
user_id=user_id,
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
)
def _make_file(**overrides) -> WorkspaceFile:
defaults = {
"id": "file-001",
"workspace_id": "ws-001",
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"name": "test.txt",
"path": "/test.txt",
"storage_path": "local://test.txt",
"mime_type": "text/plain",
"size_bytes": 100,
"checksum": None,
"is_deleted": False,
"deleted_at": None,
"metadata": {},
}
defaults.update(overrides)
return WorkspaceFile(**defaults)
def _make_file_mock(**overrides) -> MagicMock:
"""Create a mock WorkspaceFile to simulate DB records with null fields."""
defaults = {
"id": "file-001",
"name": "test.txt",
"path": "/test.txt",
"mime_type": "text/plain",
"size_bytes": 100,
"metadata": {},
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
}
defaults.update(overrides)
mock = MagicMock(spec=WorkspaceFile)
for k, v in defaults.items():
setattr(mock, k, v)
return mock
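A note on why this helper exists alongside `_make_file`: assuming WorkspaceFile is a validated model (its keyword construction above suggests Pydantic), building a real instance with `metadata=None` would be coerced or rejected at validation time, so a spec'd mock is the only way to mimic a legacy DB row whose metadata column is NULL.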
# -- list_workspace_files tests --
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_returns_all_when_no_session(mock_manager_cls, mock_get_workspace):
mock_get_workspace.return_value = _make_workspace()
files = [
_make_file(id="f1", name="a.txt", metadata={"origin": "user-upload"}),
_make_file(id="f2", name="b.csv", metadata={"origin": "agent-created"}),
]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files")
assert response.status_code == 200
data = response.json()
assert len(data["files"]) == 2
assert data["has_more"] is False
assert data["offset"] == 0
assert data["files"][0]["id"] == "f1"
assert data["files"][0]["metadata"] == {"origin": "user-upload"}
assert data["files"][1]["id"] == "f2"
mock_instance.list_files.assert_called_once_with(
limit=201, offset=0, include_all_sessions=True
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_scopes_to_session_when_provided(
mock_manager_cls, mock_get_workspace, test_user_id
):
mock_get_workspace.return_value = _make_workspace(user_id=test_user_id)
mock_instance = AsyncMock()
mock_instance.list_files.return_value = []
mock_manager_cls.return_value = mock_instance
response = client.get("/files?session_id=sess-123")
assert response.status_code == 200
data = response.json()
assert data["files"] == []
assert data["has_more"] is False
mock_manager_cls.assert_called_once_with(test_user_id, "ws-001", "sess-123")
mock_instance.list_files.assert_called_once_with(
limit=201, offset=0, include_all_sessions=False
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_null_metadata_coerced_to_empty_dict(
mock_manager_cls, mock_get_workspace
):
"""Route uses `f.metadata or {}` for pre-existing files with null metadata."""
mock_get_workspace.return_value = _make_workspace()
mock_instance = AsyncMock()
mock_instance.list_files.return_value = [_make_file_mock(metadata=None)]
mock_manager_cls.return_value = mock_instance
response = client.get("/files")
assert response.status_code == 200
assert response.json()["files"][0]["metadata"] == {}
# -- upload_file metadata tests --
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
@patch("backend.api.features.workspace.routes.scan_content_safe")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_upload_passes_user_upload_origin_metadata(
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
):
mock_get_workspace.return_value = _make_workspace()
mock_total_size.return_value = 100
written = _make_file(id="new-file", name="doc.pdf")
mock_instance = AsyncMock()
mock_instance.write_file.return_value = written
mock_manager_cls.return_value = mock_instance
response = client.post(
"/files/upload",
files={"file": ("doc.pdf", b"fake-pdf-content", "application/pdf")},
)
assert response.status_code == 200
mock_instance.write_file.assert_called_once()
call_kwargs = mock_instance.write_file.call_args
assert call_kwargs.kwargs.get("metadata") == {"origin": "user-upload"}
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
@patch("backend.api.features.workspace.routes.scan_content_safe")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_upload_returns_409_on_file_conflict(
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
):
mock_get_workspace.return_value = _make_workspace()
mock_total_size.return_value = 100
mock_instance = AsyncMock()
mock_instance.write_file.side_effect = ValueError("File already exists at path")
mock_manager_cls.return_value = mock_instance
response = client.post(
"/files/upload",
files={"file": ("dup.txt", b"content", "text/plain")},
)
assert response.status_code == 409
assert "already exists" in response.json()["detail"]
# -- Restored upload/download/delete security + invariant tests --
def _upload(
filename: str = "hello.txt",
content: bytes = b"Hello, world!",
content_type: str = "text/plain",
):
"""Helper to POST a file upload."""
return client.post(
"/files/upload?session_id=sess-1",
files={"file": (filename, io.BytesIO(content), content_type)},
)
# ---- Happy path ----
_MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-001",
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
name="hello.txt",
path="/sessions/sess-1/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
def test_upload_happy_path(mocker):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -82,7 +238,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -96,10 +252,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
assert data["size_bytes"] == 13
# ---- Per-file size limit ----
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
def test_upload_exceeds_max_file_size(mocker):
"""Files larger than max_file_size_mb should be rejected with 413."""
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
@@ -109,15 +262,11 @@ def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
assert response.status_code == 413
# ---- Storage quota exceeded ----
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
def test_upload_storage_quota_exceeded(mocker):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
# Current usage already at limit
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=500 * 1024 * 1024,
@@ -128,27 +277,22 @@ def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
assert "Storage limit exceeded" in response.text
# ---- Post-write quota race (B2) ----
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
"""If a concurrent upload tips the total over the limit after write,
the file should be soft-deleted and 413 returned."""
def test_upload_post_write_quota_race(mocker):
"""Concurrent upload tipping over limit after write should soft-delete + 413."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
# Pre-write check passes (under limit), but post-write check fails
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
side_effect=[0, 600 * 1024 * 1024],
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -160,17 +304,14 @@ def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
response = _upload()
assert response.status_code == 413
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-001")
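For orientation, the race this test pins down reduces to a pre-check/write/re-check pattern. A hedged, generic sketch follows — the real route body is outside this hunk, so the callables below are illustrative stand-ins for `get_workspace_total_size`, `write_file`, and whatever delete helper `mock_delete` patches:

```python
class QuotaExceeded(Exception):
    """Maps to HTTP 413 in the route."""

async def write_with_recheck(write, usage, delete, limit_bytes: int, incoming: int):
    # Pre-write check: reject uploads that would exceed the quota outright.
    if await usage() + incoming > limit_bytes:
        raise QuotaExceeded()
    record = await write()
    # Post-write re-check: a concurrent upload may have landed between the
    # pre-check and our write. Undo ours and surface 413 (the B2 race).
    if await usage() > limit_bytes:
        await delete(record)
        raise QuotaExceeded()
    return record
```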
# ---- Any extension accepted (no allowlist) ----
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
def test_upload_any_extension(mocker):
"""Any file extension should be accepted — ClamAV is the security layer."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -181,7 +322,7 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -191,16 +332,13 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
assert response.status_code == 200
# ---- Virus scan rejection ----
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
def test_upload_blocked_by_virus_scan(mocker):
"""Files flagged by ClamAV should be rejected and never written to storage."""
from backend.api.features.store.exceptions import VirusDetectedError
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -211,7 +349,7 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
side_effect=VirusDetectedError("Eicar-Test-Signature"),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -219,18 +357,14 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
assert response.status_code == 400
assert "Virus detected" in response.text
mock_manager.write_file.assert_not_called()
# ---- No file extension ----
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
def test_upload_file_without_extension(mocker):
"""Files without an extension should be accepted and stored as-is."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -241,7 +375,7 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -257,14 +391,11 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
assert mock_manager.write_file.call_args[0][1] == "Makefile"
# ---- Filename sanitization (SF5) ----
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
def test_upload_strips_path_components(mocker):
"""Path-traversal filenames should be reduced to their basename."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -275,28 +406,23 @@ def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
# Filename with traversal
_upload(filename="../../etc/passwd.txt")
# write_file should have been called with just the basename
mock_manager.write_file.assert_called_once()
call_args = mock_manager.write_file.call_args
assert call_args[0][1] == "passwd.txt"
# ---- Download ----
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
def test_download_file_not_found(mocker):
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_file",
@@ -307,14 +433,11 @@ def test_download_file_not_found(mocker: pytest_mock.MockFixture):
assert response.status_code == 404
# ---- Delete ----
def test_delete_file_success(mocker: pytest_mock.MockFixture):
def test_delete_file_success(mocker):
"""Deleting an existing file should return {"deleted": true}."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
@@ -329,11 +452,11 @@ def test_delete_file_success(mocker: pytest_mock.MockFixture):
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
def test_delete_file_not_found(mocker):
"""Deleting a non-existent file should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
@@ -347,7 +470,7 @@ def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
assert "File not found" in response.text
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
def test_delete_file_no_workspace(mocker):
"""Deleting when user has no workspace should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
@@ -357,3 +480,123 @@ def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
response = client.delete("/files/file-aaa-bbb")
assert response.status_code == 404
assert "Workspace not found" in response.text
def test_upload_write_file_too_large_returns_413(mocker):
"""write_file raises ValueError("File too large: …") → must map to 413."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit")
)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 413
assert "File too large" in response.text
def test_upload_write_file_conflict_returns_409(mocker):
"""Non-'File too large' ValueErrors from write_file stay as 409."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=ValueError("File already exists at path: /sessions/x/a.txt")
)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 409
assert "already exists" in response.text
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_has_more_true_when_limit_exceeded(
mock_manager_cls, mock_get_workspace
):
"""The limit+1 fetch trick must flip has_more=True and trim the page."""
mock_get_workspace.return_value = _make_workspace()
# The route will ask the backend for limit+1=3; getting all 3 back means more pages exist.
files = [
_make_file(id="f1", name="a.txt"),
_make_file(id="f2", name="b.txt"),
_make_file(id="f3", name="c.txt"),
]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files?limit=2")
assert response.status_code == 200
data = response.json()
assert data["has_more"] is True
assert len(data["files"]) == 2
assert data["files"][0]["id"] == "f1"
assert data["files"][1]["id"] == "f2"
mock_instance.list_files.assert_called_once_with(
limit=3, offset=0, include_all_sessions=True
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_has_more_false_when_exactly_page_size(
mock_manager_cls, mock_get_workspace
):
"""Exactly `limit` rows means we're on the last page — has_more=False."""
mock_get_workspace.return_value = _make_workspace()
files = [_make_file(id="f1", name="a.txt"), _make_file(id="f2", name="b.txt")]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files?limit=2")
assert response.status_code == 200
data = response.json()
assert data["has_more"] is False
assert len(data["files"]) == 2
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
mock_get_workspace.return_value = _make_workspace()
mock_instance = AsyncMock()
mock_instance.list_files.return_value = []
mock_manager_cls.return_value = mock_instance
response = client.get("/files?offset=50&limit=10")
assert response.status_code == 200
assert response.json()["offset"] == 50
mock_instance.list_files.assert_called_once_with(
limit=11, offset=50, include_all_sessions=True
)

View File

@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.platform_cost_routes
import backend.api.features.admin.rate_limit_admin_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
@@ -329,6 +330,11 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/copilot",
)
app.include_router(
backend.api.features.admin.platform_cost_routes.router,
tags=["v2", "admin"],
prefix="/api/admin",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],

View File

@@ -17,7 +17,7 @@ from backend.blocks.apollo.models import (
PrimaryPhone,
SearchOrganizationsRequest,
)
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
class SearchOrganizationsBlock(Block):
@@ -218,6 +218,11 @@ To find IDs, identify the values for organization_id when you call this endpoint
) -> BlockOutput:
query = SearchOrganizationsRequest(**input_data.model_dump())
organizations = await self.search_organizations(query, credentials)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(organizations)), provider_cost_type="items"
)
)
for organization in organizations:
yield "organization", organization
yield "organizations", organizations

View File

@@ -21,7 +21,7 @@ from backend.blocks.apollo.models import (
SearchPeopleRequest,
SenorityLevels,
)
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
class SearchPeopleBlock(Block):
@@ -366,4 +366,9 @@ class SearchPeopleBlock(Block):
*(enrich_or_fallback(person) for person in people)
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(people)), provider_cost_type="items"
)
)
yield "people", people

View File

@@ -0,0 +1,712 @@
"""Unit tests for merge_stats cost tracking in individual blocks.
Covers the Exa code_context and contents blocks plus the Apollo organization/
people, Jina embedding, Unreal text-to-speech, Google Maps, and SmartLead blocks
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, NodeExecutionStats
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
TEST_EXA_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="exa",
api_key=SecretStr("mock-exa-api-key"),
title="Mock Exa API key",
expires_at=None,
)
TEST_EXA_CREDENTIALS_INPUT = {
"provider": TEST_EXA_CREDENTIALS.provider,
"id": TEST_EXA_CREDENTIALS.id,
"type": TEST_EXA_CREDENTIALS.type,
"title": TEST_EXA_CREDENTIALS.title,
}
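Every test in this module captures `merge_stats` calls the same way: patch it on the block instance and collect the `NodeExecutionStats` arguments. Factored out as a sketch (a hypothetical helper, not part of this PR), the harness is just:

```python
import contextlib
from unittest.mock import patch

from backend.data.model import NodeExecutionStats

@contextlib.contextmanager
def capture_merge_stats(block):
    """Collect every NodeExecutionStats the block passes to merge_stats."""
    captured: list[NodeExecutionStats] = []
    with patch.object(block, "merge_stats", side_effect=captured.append):
        yield captured
```

Asserting on the captured list keeps each test independent of how merge_stats aggregates stats internally.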
# ---------------------------------------------------------------------------
# ExaCodeContextBlock — cost_dollars is a string like "0.005"
# ---------------------------------------------------------------------------
class TestExaCodeContextBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_float_cost(self):
"""float(cost_dollars) parsed from API string and passed to merge_stats."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-1",
"query": "how to use hooks",
"response": "Here are some examples...",
"resultsCount": 3,
"costDollars": "0.005",
"searchTime": 1.2,
"outputTokens": 100,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="how to use hooks",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
results = []
async for output in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
results.append(output)
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(0.005)
@pytest.mark.asyncio
async def test_invalid_cost_dollars_does_not_raise(self):
"""When cost_dollars cannot be parsed as float, merge_stats is not called."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-2",
"query": "query",
"response": "response",
"resultsCount": 0,
"costDollars": "N/A",
"searchTime": 0.5,
"outputTokens": 0,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
merge_calls: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: merge_calls.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="query",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert merge_calls == []
@pytest.mark.asyncio
async def test_zero_cost_is_tracked(self):
"""A zero cost_dollars string '0.0' should still be recorded."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-3",
"query": "query",
"response": "...",
"resultsCount": 1,
"costDollars": "0.0",
"searchTime": 0.1,
"outputTokens": 10,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="query",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
# ---------------------------------------------------------------------------
# ExaContentsBlock — response.cost_dollars.total (CostDollars model)
# ---------------------------------------------------------------------------
class TestExaContentsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_cost_dollars_total(self):
"""provider_cost equals response.cost_dollars.total when present."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
cost_dollars = CostDollars(total=0.012)
mock_response = MagicMock()
mock_response.results = []
mock_response.context = None
mock_response.statuses = None
mock_response.cost_dollars = cost_dollars
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.contents.AsyncExa",
return_value=MagicMock(
get_contents=AsyncMock(return_value=mock_response)
),
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaContentsBlock.Input(
urls=["https://example.com"],
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(0.012)
@pytest.mark.asyncio
async def test_no_merge_stats_when_cost_dollars_absent(self):
"""When response.cost_dollars is None, merge_stats is not called."""
from backend.blocks.exa.contents import ExaContentsBlock
block = ExaContentsBlock()
mock_response = MagicMock()
mock_response.results = []
mock_response.context = None
mock_response.statuses = None
mock_response.cost_dollars = None
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.contents.AsyncExa",
return_value=MagicMock(
get_contents=AsyncMock(return_value=mock_response)
),
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaContentsBlock.Input(
urls=["https://example.com"],
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert accumulated == []
# ---------------------------------------------------------------------------
# SearchOrganizationsBlock — provider_cost = float(len(organizations))
# ---------------------------------------------------------------------------
class TestSearchOrganizationsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_org_count(self):
"""provider_cost == number of returned organizations, type == 'items'."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.models import Organization
from backend.blocks.apollo.organization import SearchOrganizationsBlock
block = SearchOrganizationsBlock()
fake_orgs = [Organization(id=str(i), name=f"Org{i}") for i in range(3)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchOrganizationsBlock,
"search_organizations",
new_callable=AsyncMock,
return_value=fake_orgs,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchOrganizationsBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
results = []
async for output in block.run(
input_data,
credentials=APOLLO_CREDS,
):
results.append(output)
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(3.0)
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_org_list_tracks_zero(self):
"""An empty organization list results in provider_cost=0.0."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.organization import SearchOrganizationsBlock
block = SearchOrganizationsBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchOrganizationsBlock,
"search_organizations",
new_callable=AsyncMock,
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchOrganizationsBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=APOLLO_CREDS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# JinaEmbeddingBlock — token count from usage.total_tokens
# ---------------------------------------------------------------------------
class TestJinaEmbeddingBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_token_count(self):
"""provider token count is recorded when API returns usage.total_tokens."""
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
block = JinaEmbeddingBlock()
api_response = {
"data": [{"embedding": [0.1, 0.2, 0.3]}],
"usage": {"total_tokens": 42},
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.jina.embeddings.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = JinaEmbeddingBlock.Input(
texts=["hello world"],
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=JINA_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].input_token_count == 42
@pytest.mark.asyncio
async def test_no_merge_stats_when_usage_absent(self):
"""When API response omits usage field, merge_stats is not called."""
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
block = JinaEmbeddingBlock()
api_response = {
"data": [{"embedding": [0.1, 0.2, 0.3]}],
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.jina.embeddings.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = JinaEmbeddingBlock.Input(
texts=["hello"],
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=JINA_CREDS):
pass
assert accumulated == []
# ---------------------------------------------------------------------------
# UnrealTextToSpeechBlock — character count from input text length
# ---------------------------------------------------------------------------
class TestUnrealTextToSpeechBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_character_count(self):
"""provider_cost equals len(text) with type='characters'."""
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
from backend.blocks.text_to_speech_block import (
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
)
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
block = UnrealTextToSpeechBlock()
test_text = "Hello, world!"
with (
patch.object(
UnrealTextToSpeechBlock,
"call_unreal_speech_api",
new_callable=AsyncMock,
return_value={"OutputUri": "https://example.com/audio.mp3"},
),
patch.object(block, "merge_stats") as mock_merge,
):
input_data = UnrealTextToSpeechBlock.Input(
text=test_text,
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=TTS_CREDS):
pass
mock_merge.assert_called_once()
stats = mock_merge.call_args[0][0]
assert stats.provider_cost == float(len(test_text))
assert stats.provider_cost_type == "characters"
@pytest.mark.asyncio
async def test_empty_text_gives_zero_characters(self):
"""An empty text string results in provider_cost=0.0."""
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
from backend.blocks.text_to_speech_block import (
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
)
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
block = UnrealTextToSpeechBlock()
with (
patch.object(
UnrealTextToSpeechBlock,
"call_unreal_speech_api",
new_callable=AsyncMock,
return_value={"OutputUri": "https://example.com/audio.mp3"},
),
patch.object(block, "merge_stats") as mock_merge,
):
input_data = UnrealTextToSpeechBlock.Input(
text="",
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=TTS_CREDS):
pass
mock_merge.assert_called_once()
stats = mock_merge.call_args[0][0]
assert stats.provider_cost == 0.0
assert stats.provider_cost_type == "characters"
# ---------------------------------------------------------------------------
# GoogleMapsSearchBlock — item count from search_places results
# ---------------------------------------------------------------------------
class TestGoogleMapsSearchBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_place_count(self):
"""provider_cost equals number of returned places, type == 'items'."""
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
from backend.blocks.google_maps import (
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
)
from backend.blocks.google_maps import GoogleMapsSearchBlock
block = GoogleMapsSearchBlock()
fake_places = [{"name": f"Place{i}", "address": f"Addr{i}"} for i in range(4)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
GoogleMapsSearchBlock,
"search_places",
return_value=fake_places,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = GoogleMapsSearchBlock.Input(
query="coffee shops",
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=MAPS_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 4.0
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_results_tracks_zero(self):
"""Zero places returned results in provider_cost=0.0."""
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
from backend.blocks.google_maps import (
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
)
from backend.blocks.google_maps import GoogleMapsSearchBlock
block = GoogleMapsSearchBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
GoogleMapsSearchBlock,
"search_places",
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = GoogleMapsSearchBlock.Input(
query="nothing here",
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=MAPS_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# SmartLeadAddLeadsBlock — item count from lead_list length
# ---------------------------------------------------------------------------
class TestSmartLeadAddLeadsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_lead_count(self):
"""provider_cost equals number of leads uploaded, type == 'items'."""
from backend.blocks.smartlead._auth import TEST_CREDENTIALS as SL_CREDS
from backend.blocks.smartlead._auth import (
TEST_CREDENTIALS_INPUT as SL_CREDS_INPUT,
)
from backend.blocks.smartlead.campaign import AddLeadToCampaignBlock
from backend.blocks.smartlead.models import (
AddLeadsToCampaignResponse,
LeadInput,
)
block = AddLeadToCampaignBlock()
fake_leads = [
LeadInput(first_name="Alice", last_name="A", email="alice@example.com"),
LeadInput(first_name="Bob", last_name="B", email="bob@example.com"),
]
fake_response = AddLeadsToCampaignResponse(
ok=True,
upload_count=2,
total_leads=2,
block_count=0,
duplicate_count=0,
invalid_email_count=0,
invalid_emails=[],
already_added_to_campaign=0,
unsubscribed_leads=[],
is_lead_limit_exhausted=False,
lead_import_stopped_count=0,
bounce_count=0,
)
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
AddLeadToCampaignBlock,
"add_leads_to_campaign",
new_callable=AsyncMock,
return_value=fake_response,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = AddLeadToCampaignBlock.Input(
campaign_id=123,
lead_list=fake_leads,
credentials=SL_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=SL_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 2.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# SearchPeopleBlock — item count from people list length
# ---------------------------------------------------------------------------
class TestSearchPeopleBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_people_count(self):
"""provider_cost equals number of returned people, type == 'items'."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.models import Contact
from backend.blocks.apollo.people import SearchPeopleBlock
block = SearchPeopleBlock()
fake_people = [Contact(id=str(i), first_name=f"Person{i}") for i in range(5)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchPeopleBlock,
"search_people",
new_callable=AsyncMock,
return_value=fake_people,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchPeopleBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(5.0)
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_people_list_tracks_zero(self):
"""An empty people list results in provider_cost=0.0."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.people import SearchPeopleBlock
block = SearchPeopleBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchPeopleBlock,
"search_people",
new_callable=AsyncMock,
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchPeopleBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"

View File

@@ -9,6 +9,7 @@ from typing import Union
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -116,3 +117,10 @@ class ExaCodeContextBlock(Block):
yield "cost_dollars", context.cost_dollars
yield "search_time", context.search_time
yield "output_tokens", context.output_tokens
# Parse cost_dollars (API returns as string, e.g. "0.005")
try:
cost_usd = float(context.cost_dollars)
self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
except (ValueError, TypeError):
pass

View File

@@ -4,6 +4,7 @@ from typing import Optional
from exa_py import AsyncExa
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -223,3 +224,6 @@ class ExaContentsBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -0,0 +1,575 @@
"""Tests for cost tracking in Exa blocks.
Covers the cost_dollars → provider_cost → merge_stats path for the Exa
code_context, contents, search, find-similar, and research blocks.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.data.model import NodeExecutionStats
class TestExaCodeContextCostTracking:
"""ExaCodeContextBlock parses cost_dollars (string) and calls merge_stats."""
@pytest.mark.asyncio
async def test_valid_cost_string_is_parsed_and_merged(self):
"""A numeric cost string like '0.005' is merged as provider_cost."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-1",
"query": "test query",
"response": "some code",
"resultsCount": 3,
"costDollars": "0.005",
"searchTime": 1.2,
"outputTokens": 100,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
outputs = []
async for key, value in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
outputs.append((key, value))
assert any(k == "cost_dollars" for k, _ in outputs)
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.005)
@pytest.mark.asyncio
async def test_invalid_cost_string_does_not_raise(self):
"""A non-numeric cost_dollars value is swallowed silently."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-2",
"query": "test",
"response": "code",
"resultsCount": 0,
"costDollars": "N/A",
"searchTime": 0.5,
"outputTokens": 0,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
outputs = []
async for key, value in block.run(
block.Input(query="test", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
outputs.append((key, value))
# No merge_stats call because float() raised ValueError
assert len(merged) == 0
@pytest.mark.asyncio
async def test_zero_cost_string_is_merged(self):
"""'0.0' is a valid cost — should still be tracked."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-3",
"query": "free query",
"response": "result",
"resultsCount": 1,
"costDollars": "0.0",
"searchTime": 0.1,
"outputTokens": 10,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
async for _ in block.run(
block.Input(query="free query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.0)
class TestExaContentsCostTracking:
"""ExaContentsBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = CostDollars(total=0.012)
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.012)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.contents import ExaContentsBlock
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
@pytest.mark.asyncio
async def test_zero_cost_dollars_is_merged(self):
"""A total of 0.0 (free tier) should still be merged."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = CostDollars(total=0.0)
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.0)
class TestExaSearchCostTracking:
"""ExaSearchBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.helpers import CostDollars
from backend.blocks.exa.search import ExaSearchBlock
block = ExaSearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.resolved_search_type = None
mock_sdk_response.cost_dollars = CostDollars(total=0.008)
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.008)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.search import ExaSearchBlock
block = ExaSearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.resolved_search_type = None
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
class TestExaSimilarCostTracking:
"""ExaFindSimilarBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.helpers import CostDollars
from backend.blocks.exa.similar import ExaFindSimilarBlock
block = ExaFindSimilarBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.request_id = "req-1"
mock_sdk_response.cost_dollars = CostDollars(total=0.015)
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.015)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.similar import ExaFindSimilarBlock
block = ExaFindSimilarBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.request_id = "req-2"
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
# ---------------------------------------------------------------------------
# ExaCreateResearchBlock — cost_dollars from completed poll response
# ---------------------------------------------------------------------------
COMPLETED_RESEARCH_RESPONSE = {
"researchId": "test-research-id",
"status": "completed",
"model": "exa-research",
"instructions": "test instructions",
"createdAt": 1700000000000,
"finishedAt": 1700000060000,
"costDollars": {
"total": 0.05,
"numSearches": 3,
"numPages": 10,
"reasoningTokens": 500,
},
"output": {"content": "Research findings...", "parsed": None},
}
PENDING_RESEARCH_RESPONSE = {
"researchId": "test-research-id",
"status": "pending",
"model": "exa-research",
"instructions": "test instructions",
"createdAt": 1700000000000,
}
class TestExaCreateResearchBlockCostTracking:
"""ExaCreateResearchBlock merges cost from completed poll response."""
@pytest.mark.asyncio
async def test_cost_merged_when_research_completes(self):
"""merge_stats called with provider_cost=total when poll returns completed."""
from backend.blocks.exa.research import ExaCreateResearchBlock
block = ExaCreateResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
create_resp = MagicMock()
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
poll_resp = MagicMock()
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.post = AsyncMock(return_value=create_resp)
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
instructions="test instructions",
wait_for_completion=True,
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
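A side note on this and the remaining polling tests: patching `asyncio.sleep` with an `AsyncMock` turns the create-then-poll loop into a zero-delay iteration, so the completed response on the first `get` ends the loop immediately; without it, each test would block for the block's real `check_interval`.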
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When completed response has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaCreateResearchBlock
block = ExaCreateResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
create_resp = MagicMock()
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
poll_resp = MagicMock()
poll_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.post = AsyncMock(return_value=create_resp)
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
instructions="test instructions",
wait_for_completion=True,
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []
# ---------------------------------------------------------------------------
# ExaGetResearchBlock — cost_dollars from single GET response
# ---------------------------------------------------------------------------
class TestExaGetResearchBlockCostTracking:
"""ExaGetResearchBlock merges cost when the fetched research has cost_dollars."""
@pytest.mark.asyncio
async def test_cost_merged_from_completed_research(self):
"""merge_stats called with provider_cost=total when research has costDollars."""
from backend.blocks.exa.research import ExaGetResearchBlock
block = ExaGetResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
get_resp = MagicMock()
get_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=get_resp)
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When research has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaGetResearchBlock
block = ExaGetResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
get_resp = MagicMock()
get_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=get_resp)
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []
# ---------------------------------------------------------------------------
# ExaWaitForResearchBlock — cost_dollars from polling response
# ---------------------------------------------------------------------------
class TestExaWaitForResearchBlockCostTracking:
"""ExaWaitForResearchBlock merges cost when the polled research has cost_dollars."""
@pytest.mark.asyncio
async def test_cost_merged_when_research_completes(self):
"""merge_stats called with provider_cost=total once polling returns completed."""
from backend.blocks.exa.research import ExaWaitForResearchBlock
block = ExaWaitForResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
poll_resp = MagicMock()
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When completed research has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaWaitForResearchBlock
block = ExaWaitForResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
poll_resp = MagicMock()
poll_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []

View File

@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -232,6 +233,11 @@ class ExaCreateResearchBlock(Block):
if research.cost_dollars:
yield "cost_total", research.cost_dollars.total
self.merge_stats(
NodeExecutionStats(
provider_cost=research.cost_dollars.total
)
)
return
await asyncio.sleep(check_interval)
@@ -346,6 +352,9 @@ class ExaGetResearchBlock(Block):
yield "cost_searches", research.cost_dollars.num_searches
yield "cost_pages", research.cost_dollars.num_pages
yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
self.merge_stats(
NodeExecutionStats(provider_cost=research.cost_dollars.total)
)
yield "error_message", research.error
@@ -432,6 +441,9 @@ class ExaWaitForResearchBlock(Block):
if research.cost_dollars:
yield "cost_total", research.cost_dollars.total
self.merge_stats(
NodeExecutionStats(provider_cost=research.cost_dollars.total)
)
return

View File

@@ -4,6 +4,7 @@ from typing import Optional
from exa_py import AsyncExa
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -206,3 +207,6 @@ class ExaSearchBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -3,6 +3,7 @@ from typing import Optional
from exa_py import AsyncExa
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -167,3 +168,6 @@ class ExaFindSimilarBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -14,6 +14,7 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -117,6 +118,11 @@ class GoogleMapsSearchBlock(Block):
input_data.radius,
input_data.max_results,
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(places)), provider_cost_type="items"
)
)
for place in places:
yield "place", place

View File

@@ -10,7 +10,7 @@ from backend.blocks.jina._auth import (
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.model import SchemaField
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.request import Requests
@@ -45,5 +45,13 @@ class JinaEmbeddingBlock(Block):
}
data = {"input": input_data.texts, "model": input_data.model}
response = await Requests().post(url, headers=headers, json=data)
embeddings = [e["embedding"] for e in response.json()["data"]]
resp_json = response.json()
embeddings = [e["embedding"] for e in resp_json["data"]]
usage = resp_json.get("usage", {})
if usage.get("total_tokens"):
self.merge_stats(
NodeExecutionStats(
input_token_count=usage.get("total_tokens", 0),
)
)
yield "embeddings", embeddings

View File

@@ -1,6 +1,7 @@
# This file contains a lot of prompt block strings that would trigger "line too long"
# flake8: noqa: E501
import logging
import math
import re
import secrets
from abc import ABC
@@ -13,6 +14,7 @@ import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
from pydantic import BaseModel, SecretStr
from backend.blocks._base import (
@@ -737,6 +739,7 @@ class LLMResponse(BaseModel):
prompt_tokens: int
completion_tokens: int
reasoning: Optional[str] = None
provider_cost: float | None = None
def convert_openai_tool_fmt_to_anthropic(
@@ -771,6 +774,35 @@ def convert_openai_tool_fmt_to_anthropic(
return anthropic_tools
def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None:
"""Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response.
OpenRouter returns the per-request USD cost in a response header. The
OpenAI SDK exposes the raw httpx response via an undocumented `_response`
attribute. We use try/except AttributeError so that if the SDK ever drops
or renames that attribute, the warning is visible in logs rather than
silently degrading to no cost tracking.
"""
try:
raw_resp = response._response # type: ignore[attr-defined]
except AttributeError:
logger.warning(
"OpenAI SDK response missing _response attribute"
" — OpenRouter cost tracking unavailable"
)
return None
try:
cost_header = raw_resp.headers.get("x-total-cost")
if not cost_header:
return None
cost = float(cost_header)
if not math.isfinite(cost) or cost < 0:
return None
return cost
except (ValueError, TypeError, AttributeError):
return None
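# Illustrative usage sketch (not part of this change; `client` and `stats` are
# hypothetical stand-ins for an OpenRouter-backed OpenAI client and a
# NodeExecutionStats instance):
#
#     response = await client.chat.completions.create(model=model, messages=msgs)
#     cost = extract_openrouter_cost(response)  # e.g. 0.0042, or None
#     if cost is not None:
#         stats.provider_cost = cost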
def extract_openai_reasoning(response) -> str | None:
"""Extract reasoning from OpenAI-compatible response if available."""
"""Note: This will likely not working since the reasoning is not present in another Response API"""
@@ -1103,6 +1135,7 @@ async def llm_call(
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
provider_cost=extract_openrouter_cost(response),
)
elif provider == "llama_api":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -1410,6 +1443,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
error_feedback_message = ""
llm_model = input_data.model
last_attempt_cost: float | None = None
for retry_count in range(input_data.retry):
logger.debug(f"LLM request: {prompt}")
@@ -1427,12 +1461,15 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
max_tokens=input_data.max_tokens,
)
response_text = llm_response.response
self.merge_stats(
NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
)
# Merge token counts for every attempt (each call costs tokens).
# provider_cost (actual USD) is tracked separately and only merged
# on success to avoid double-counting across retries.
token_stats = NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
)
self.merge_stats(token_stats)
last_attempt_cost = llm_response.provider_cost
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
if input_data.expected_format:
@@ -1501,6 +1538,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
provider_cost=last_attempt_cost,
)
)
yield "response", response_obj
@@ -1521,6 +1559,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
provider_cost=last_attempt_cost,
)
)
yield "response", {"response": response_text}

View File

@@ -251,8 +251,13 @@ def _convert_raw_response_to_dict(
# Already a dict (from tests or some providers)
return raw_response
elif _is_responses_api_object(raw_response):
# OpenAI Responses API: extract individual output items
items = [json.to_dict(item) for item in raw_response.output]
# OpenAI Responses API: extract individual output items.
# Strip 'status' — it's a response-only field that OpenAI rejects
# when the item is sent back as input on the next API call.
items = [
{k: v for k, v in json.to_dict(item).items() if k != "status"}
for item in raw_response.output
]
return items if items else [{"role": "assistant", "content": ""}]
else:
# Chat Completions / Anthropic return message objects
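For illustration (item shape assumed from the tests below): a Responses API
function_call item such as

    {"type": "function_call", "call_id": "call_xyz", "name": "my_tool",
     "arguments": '{"x": 1}', "status": "completed"}

is stored back without "status", since OpenAI rejects that field when the item
is replayed as input on the next call.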

View File

@@ -23,7 +23,7 @@ from backend.blocks.smartlead.models import (
SaveSequencesResponse,
Sequence,
)
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
class CreateCampaignBlock(Block):
@@ -226,6 +226,12 @@ class AddLeadToCampaignBlock(Block):
response = await self.add_leads_to_campaign(
input_data.campaign_id, input_data.lead_list, credentials
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(input_data.lead_list)),
provider_cost_type="items",
)
)
yield "campaign_id", input_data.campaign_id
yield "upload_count", response.upload_count

View File

@@ -199,6 +199,66 @@ class TestLLMStatsTracking:
assert block.execution_stats.llm_call_count == 2 # retry_count + 1 = 1 + 1 = 2
assert block.execution_stats.llm_retry_count == 1
@pytest.mark.asyncio
async def test_retry_cost_uses_last_attempt_only(self):
"""provider_cost is only merged from the final successful attempt.
Intermediate retry costs are intentionally dropped to avoid
double-counting: only the final attempt's cost (last_attempt_cost)
is merged, and only once the loop eventually succeeds.
"""
import backend.blocks.llm as llm
block = llm.AIStructuredResponseGeneratorBlock()
call_count = 0
async def mock_llm_call(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
# First attempt: fails validation, returns cost $0.01
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"wrong": "key"}</json_output>',
tool_calls=None,
prompt_tokens=10,
completion_tokens=5,
reasoning=None,
provider_cost=0.01,
)
# Second attempt: succeeds, returns cost $0.02
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
tool_calls=None,
prompt_tokens=20,
completion_tokens=10,
reasoning=None,
provider_cost=0.02,
)
block.llm_call = mock_llm_call # type: ignore
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
)
with patch("secrets.token_hex", return_value="test123456"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
# Only the final successful attempt's cost is merged
assert block.execution_stats.provider_cost == pytest.approx(0.02)
# Tokens from both attempts accumulate
assert block.execution_stats.input_token_count == 30
assert block.execution_stats.output_token_count == 15
@pytest.mark.asyncio
async def test_ai_text_summarizer_multiple_chunks(self):
"""Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""
@@ -987,3 +1047,67 @@ class TestLlmModelMissing:
assert (
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
)
class TestExtractOpenRouterCost:
"""Tests for extract_openrouter_cost — the x-total-cost header parser."""
def _mk_response(self, headers: dict | None):
response = MagicMock()
if headers is None:
response._response = None
else:
raw = MagicMock()
raw.headers = headers
response._response = raw
return response
def test_extracts_numeric_cost(self):
response = self._mk_response({"x-total-cost": "0.0042"})
assert llm.extract_openrouter_cost(response) == 0.0042
def test_returns_none_when_header_missing(self):
response = self._mk_response({})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_header_empty_string(self):
response = self._mk_response({"x-total-cost": ""})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_header_non_numeric(self):
response = self._mk_response({"x-total-cost": "not-a-number"})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_no_response_attr(self):
response = MagicMock(spec=[]) # no _response attr
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_raw_is_none(self):
response = self._mk_response(None)
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_raw_has_no_headers(self):
response = MagicMock()
response._response = MagicMock(spec=[]) # no headers attr
assert llm.extract_openrouter_cost(response) is None
def test_returns_zero_for_zero_cost(self):
"""Zero-cost is a valid value (free tier) and must not become None."""
response = self._mk_response({"x-total-cost": "0"})
assert llm.extract_openrouter_cost(response) == 0.0
def test_returns_none_for_inf(self):
response = self._mk_response({"x-total-cost": "inf"})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_for_negative_inf(self):
response = self._mk_response({"x-total-cost": "-inf"})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_for_nan(self):
response = self._mk_response({"x-total-cost": "nan"})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_for_negative_cost(self):
response = self._mk_response({"x-total-cost": "-0.005"})
assert llm.extract_openrouter_cost(response) is None

View File

@@ -211,6 +211,30 @@ class TestConvertRawResponseToDict:
# A single dict is wrong — there are two distinct items
pytest.fail("Expected a list of output items, got a single dict")
def test_responses_api_strips_status_from_function_call(self):
"""Responses API function_call items have a 'status' field that OpenAI
rejects when sent back as input ('Unknown parameter: input[N].status').
It must be stripped before the item is stored in conversation history."""
resp = _MockResponse(
output=[_MockFunctionCall("my_tool", '{"x": 1}', call_id="call_xyz")]
)
result = _convert_raw_response_to_dict(resp)
assert isinstance(result, list)
for item in result:
assert (
"status" not in item
), f"'status' must be stripped from Responses API items: {item}"
def test_responses_api_strips_status_from_message(self):
"""Responses API message items also carry 'status'; it must be stripped."""
resp = _MockResponse(output=[_MockOutputMessage("Hello")])
result = _convert_raw_response_to_dict(resp)
assert isinstance(result, list)
for item in result:
assert (
"status" not in item
), f"'status' must be stripped from Responses API items: {item}"
# ───────────────────────────────────────────────────────────────────────────
# _get_tool_requests (lines 61-86)

View File

@@ -13,6 +13,7 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -104,4 +105,10 @@ class UnrealTextToSpeechBlock(Block):
input_data.text,
input_data.voice_id,
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(input_data.text)),
provider_cost_type="characters",
)
)
yield "mp3_url", api_response["OutputUri"]

View File

@@ -9,6 +9,7 @@ shared tool registry as the SDK path.
import asyncio
import base64
import logging
import math
import os
import re
import shutil
@@ -22,6 +23,7 @@ from typing import TYPE_CHECKING, Any, cast
import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from opentelemetry import trace as otel_trace
from backend.copilot.config import CopilotMode
from backend.copilot.context import get_workspace_manager, set_execution_context
@@ -30,7 +32,6 @@ from backend.copilot.model import (
ChatSession,
get_chat_session,
maybe_append_user_message,
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
@@ -51,8 +52,8 @@ from backend.copilot.response_model import (
)
from backend.copilot.service import (
_build_system_prompt,
_generate_session_title,
_get_openai_client,
_update_title_async,
config,
)
from backend.copilot.token_tracking import persist_and_record_usage
@@ -334,6 +335,7 @@ class _BaselineStreamState:
text_started: bool = False
turn_prompt_tokens: int = 0
turn_completion_tokens: int = 0
cost_usd: float | None = None
thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper)
session_messages: list[ChatMessage] = field(default_factory=list)
@@ -354,6 +356,7 @@ async def _baseline_llm_caller(
state.thinking_stripper = _ThinkingStripper()
round_text = ""
response = None # initialized before try so finally block can access it
try:
client = _get_openai_client()
typed_messages = cast(list[ChatCompletionMessageParam], messages)
@@ -430,6 +433,20 @@ async def _baseline_llm_caller(
state.text_started = False
state.text_block_id = str(uuid.uuid4())
finally:
# Extract OpenRouter cost from response headers (in finally so we
# capture cost even when the stream errors mid-way — we already paid).
# Accumulate across multi-round tool-calling turns.
try:
# Access undocumented _response attribute — same pattern as
# extract_openrouter_cost() in blocks/llm.py.
cost_header = response._response.headers.get("x-total-cost") # type: ignore[attr-defined]
if cost_header:
cost = float(cost_header)
if math.isfinite(cost) and cost >= 0:
state.cost_usd = (state.cost_usd or 0.0) + cost
except (AttributeError, ValueError):
pass
# Always persist partial text so the session history stays consistent,
# even when the stream is interrupted by an exception.
state.assistant_text += round_text
@@ -686,18 +703,6 @@ def _baseline_conversation_updater(
)
async def _update_title_async(
session_id: str, message: str, user_id: str | None
) -> None:
"""Generate and persist a session title in the background."""
try:
title = await _generate_session_title(message, user_id, session_id)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
except Exception as e:
logger.warning("[Baseline] Failed to update session title: %s", e)
async def _compress_session_messages(
messages: list[ChatMessage],
model: str,
@@ -1183,8 +1188,22 @@ async def stream_chat_completion_baseline(
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
# Close Langfuse trace context
# Set cost attributes on OTEL span before closing
if _trace_ctx is not None:
try:
span = otel_trace.get_current_span()
if span and span.is_recording():
span.set_attribute(
"gen_ai.usage.prompt_tokens", state.turn_prompt_tokens
)
span.set_attribute(
"gen_ai.usage.completion_tokens",
state.turn_completion_tokens,
)
if state.cost_usd is not None:
span.set_attribute("gen_ai.usage.cost_usd", state.cost_usd)
except Exception:
logger.debug("[Baseline] Failed to set OTEL cost attributes")
try:
_trace_ctx.__exit__(None, None, None)
except Exception:
@@ -1226,6 +1245,8 @@ async def stream_chat_completion_baseline(
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
log_prefix="[Baseline]",
cost_usd=state.cost_usd,
model=active_model,
)
# Persist structured tool-call history (assistant + tool messages)

View File

@@ -4,7 +4,7 @@ These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState`
without requiring API keys, database connections, or network access.
"""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openai.types.chat import ChatCompletionToolParam
@@ -631,3 +631,200 @@ class TestPrepareBaselineAttachments:
assert hint == ""
assert blocks == []
class TestBaselineCostExtraction:
"""Tests for x-total-cost header extraction in _baseline_llm_caller."""
@pytest.mark.asyncio
async def test_cost_usd_extracted_from_response_header(self):
"""state.cost_usd is set from x-total-cost header when present."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
# Build a mock raw httpx response with the cost header
mock_raw_response = MagicMock()
mock_raw_response.headers = {"x-total-cost": "0.0123"}
# Build a mock async streaming response that yields no chunks but has
# a _response attribute pointing to the mock httpx response
mock_stream_response = MagicMock()
mock_stream_response._response = mock_raw_response
async def empty_aiter():
return
yield # make it an async generator
mock_stream_response.__aiter__ = lambda self: empty_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
return_value=mock_stream_response
)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.cost_usd == pytest.approx(0.0123)
@pytest.mark.asyncio
async def test_cost_usd_accumulates_across_calls(self):
"""cost_usd accumulates when _baseline_llm_caller is called multiple times."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
def make_stream_mock(cost: str) -> MagicMock:
mock_raw = MagicMock()
mock_raw.headers = {"x-total-cost": cost}
mock_stream = MagicMock()
mock_stream._response = mock_raw
async def empty_aiter():
return
yield
mock_stream.__aiter__ = lambda self: empty_aiter()
return mock_stream
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
side_effect=[make_stream_mock("0.01"), make_stream_mock("0.02")]
)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "first"}],
tools=[],
state=state,
)
await _baseline_llm_caller(
messages=[{"role": "user", "content": "second"}],
tools=[],
state=state,
)
assert state.cost_usd == pytest.approx(0.03)
@pytest.mark.asyncio
async def test_no_cost_when_header_absent(self):
"""state.cost_usd remains None when response has no x-total-cost header."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
mock_raw = MagicMock()
mock_raw.headers = {}
mock_stream = MagicMock()
mock_stream._response = mock_raw
async def empty_aiter():
return
yield
mock_stream.__aiter__ = lambda self: empty_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.cost_usd is None
@pytest.mark.asyncio
async def test_cost_extracted_even_when_stream_raises(self):
"""cost_usd is captured in the finally block even when streaming fails."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
mock_raw = MagicMock()
mock_raw.headers = {"x-total-cost": "0.005"}
mock_stream = MagicMock()
mock_stream._response = mock_raw
async def failing_aiter():
raise RuntimeError("stream error")
yield # make it an async generator
mock_stream.__aiter__ = lambda self: failing_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
with (
patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
),
pytest.raises(RuntimeError, match="stream error"),
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.cost_usd == pytest.approx(0.005)
@pytest.mark.asyncio
async def test_no_cost_when_api_call_raises_before_stream(self):
"""finally block is safe when response is None (API call failed before yielding)."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
side_effect=RuntimeError("connection refused")
)
with (
patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
),
pytest.raises(RuntimeError, match="connection refused"),
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
# response was never assigned so cost extraction must not raise
assert state.cost_usd is None

View File

@@ -14,6 +14,7 @@ from prisma.types import (
ChatSessionUpdateInput,
ChatSessionWhereInput,
)
from pydantic import BaseModel
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
@@ -30,6 +31,15 @@ from .model import get_chat_session as get_chat_session_cached
logger = logging.getLogger(__name__)
class PaginatedMessages(BaseModel):
"""Result of a paginated message query."""
messages: list[ChatMessage]
has_more: bool
oldest_sequence: int | None
session: ChatSessionInfo
async def get_chat_session(session_id: str) -> ChatSession | None:
"""Get a chat session by ID from the database."""
session = await PrismaChatSession.prisma().find_unique(
@@ -39,6 +49,116 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
return ChatSession.from_db(session) if session else None
async def get_chat_session_metadata(session_id: str) -> ChatSessionInfo | None:
"""Get chat session metadata (without messages) for ownership validation."""
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
)
return ChatSessionInfo.from_db(session) if session else None
async def get_chat_messages_paginated(
session_id: str,
limit: int = 50,
before_sequence: int | None = None,
user_id: str | None = None,
) -> PaginatedMessages | None:
"""Get paginated messages for a session, newest first.
Verifies session existence (and ownership when ``user_id`` is provided)
in the same query that fetches the messages. Returns ``None`` when the
session is not found or does not belong to the user.
Args:
session_id: The chat session ID.
limit: Max messages to return.
before_sequence: Cursor — return messages with sequence < this value.
user_id: If provided, filters via ``Session.userId`` so only the
session owner's messages are returned (acts as an ownership guard).
"""
# Build session-existence / ownership check
session_where: ChatSessionWhereInput = {"id": session_id}
if user_id is not None:
session_where["userId"] = user_id
# Build message include — fetch paginated messages in the same query
msg_include: dict[str, Any] = {
"order_by": {"sequence": "desc"},
"take": limit + 1,
}
if before_sequence is not None:
msg_include["where"] = {"sequence": {"lt": before_sequence}}
# Single query: session existence/ownership + paginated messages
session = await PrismaChatSession.prisma().find_first(
where=session_where,
include={"Messages": msg_include},
)
if session is None:
return None
session_info = ChatSessionInfo.from_db(session)
results = list(session.Messages) if session.Messages else []
has_more = len(results) > limit
results = results[:limit]
# Reverse to ascending order
results.reverse()
# Tool-call boundary fix: if the oldest message is a tool message,
# expand backward to include the preceding assistant message that
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
# can pair them correctly.
_BOUNDARY_SCAN_LIMIT = 10
if results and results[0].role == "tool":
boundary_where: dict[str, Any] = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
boundary_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=boundary_where,
order={"sequence": "desc"},
take=_BOUNDARY_SCAN_LIMIT,
)
# Find the first non-tool message (should be the assistant)
boundary_msgs = []
found_owner = False
for msg in extra:
boundary_msgs.append(msg)
if msg.role != "tool":
found_owner = True
break
boundary_msgs.reverse()
if not found_owner:
logger.warning(
"Boundary expansion did not find owning assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
if boundary_msgs:
results = boundary_msgs + results
# Only mark has_more if the expanded boundary isn't the
# very start of the conversation (sequence 0).
if boundary_msgs[0].sequence > 0:
has_more = True
messages = [ChatMessage.from_db(m) for m in results]
oldest_sequence = messages[0].sequence if messages else None
return PaginatedMessages(
messages=messages,
has_more=has_more,
oldest_sequence=oldest_sequence,
session=session_info,
)
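# Cursor-pagination sketch (illustrative only): fetch the newest page first,
# then pass oldest_sequence back as the cursor for the next, older page.
#
#     page = await get_chat_messages_paginated("sess-1", limit=50, user_id=uid)
#     if page and page.has_more:
#         older = await get_chat_messages_paginated(
#             "sess-1", limit=50, before_sequence=page.oldest_sequence, user_id=uid
#         )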
async def create_chat_session(
session_id: str,
user_id: str,

View File

@@ -1,7 +1,341 @@
import pytest
"""Unit tests for copilot.db — paginated message queries."""
from .db import set_turn_duration
from .model import ChatMessage, ChatSession, get_chat_session, upsert_chat_session
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from backend.copilot.db import (
PaginatedMessages,
get_chat_messages_paginated,
set_turn_duration,
)
from backend.copilot.model import ChatMessage as CopilotChatMessage
from backend.copilot.model import ChatSession, get_chat_session, upsert_chat_session
def _make_msg(
sequence: int,
role: str = "assistant",
content: str | None = "hello",
tool_calls: Any = None,
) -> PrismaChatMessage:
"""Build a minimal PrismaChatMessage for testing."""
return PrismaChatMessage(
id=f"msg-{sequence}",
createdAt=datetime.now(UTC),
sessionId="sess-1",
role=role,
content=content,
sequence=sequence,
toolCalls=tool_calls,
name=None,
toolCallId=None,
refusal=None,
functionCall=None,
)
def _make_session(
session_id: str = "sess-1",
user_id: str = "user-1",
messages: list[PrismaChatMessage] | None = None,
) -> PrismaChatSession:
"""Build a minimal PrismaChatSession for testing."""
now = datetime.now(UTC)
session = PrismaChatSession.model_construct(
id=session_id,
createdAt=now,
updatedAt=now,
userId=user_id,
credentials={},
successfulAgentRuns={},
successfulAgentSchedules={},
totalPromptTokens=0,
totalCompletionTokens=0,
title=None,
metadata={},
Messages=messages or [],
)
return session
SESSION_ID = "sess-1"
@pytest.fixture()
def mock_db():
"""Patch ChatSession.prisma().find_first and ChatMessage.prisma().find_many.
find_first is used for the main query (session + included messages).
find_many is used only for boundary expansion queries.
"""
with (
patch.object(PrismaChatSession, "prisma") as mock_session_prisma,
patch.object(PrismaChatMessage, "prisma") as mock_msg_prisma,
):
find_first = AsyncMock()
mock_session_prisma.return_value.find_first = find_first
find_many = AsyncMock(return_value=[])
mock_msg_prisma.return_value.find_many = find_many
yield find_first, find_many
# ---------- Basic pagination ----------
@pytest.mark.asyncio
async def test_basic_page_returns_messages_ascending(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Messages are returned in ascending sequence order."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3), _make_msg(2), _make_msg(1)],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert isinstance(page, PaginatedMessages)
assert [m.sequence for m in page.messages] == [1, 2, 3]
assert page.has_more is False
assert page.oldest_sequence == 1
@pytest.mark.asyncio
async def test_has_more_when_results_exceed_limit(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""has_more is True when DB returns more than limit items."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3), _make_msg(2), _make_msg(1)],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
assert page is not None
assert page.has_more is True
assert len(page.messages) == 2
assert [m.sequence for m in page.messages] == [2, 3]
@pytest.mark.asyncio
async def test_empty_session_returns_no_messages(
mock_db: tuple[AsyncMock, AsyncMock],
):
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[])
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
assert page is not None
assert page.messages == []
assert page.has_more is False
assert page.oldest_sequence is None
@pytest.mark.asyncio
async def test_before_sequence_filters_correctly(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""before_sequence is passed as a where filter inside the Messages include."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(2), _make_msg(1)],
)
await get_chat_messages_paginated(SESSION_ID, limit=50, before_sequence=5)
call_kwargs = find_first.call_args
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
assert include["Messages"]["where"] == {"sequence": {"lt": 5}}
@pytest.mark.asyncio
async def test_no_where_on_messages_without_before_sequence(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Without before_sequence, the Messages include has no where clause."""
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[_make_msg(1)])
await get_chat_messages_paginated(SESSION_ID, limit=50)
call_kwargs = find_first.call_args
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
assert "where" not in include["Messages"]
@pytest.mark.asyncio
async def test_user_id_filter_applied_to_session_where(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""user_id adds a userId filter to the session-level where clause."""
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[_make_msg(1)])
await get_chat_messages_paginated(SESSION_ID, limit=50, user_id="user-abc")
call_kwargs = find_first.call_args
where = call_kwargs.kwargs.get("where") or call_kwargs[1].get("where")
assert where["userId"] == "user-abc"
@pytest.mark.asyncio
async def test_session_not_found_returns_none(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Returns None when session doesn't exist or user doesn't own it."""
find_first, _ = mock_db
find_first.return_value = None
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
assert page is None
@pytest.mark.asyncio
async def test_session_info_included_in_result(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""PaginatedMessages includes session metadata."""
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[_make_msg(1)])
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
assert page is not None
assert page.session.session_id == SESSION_ID
# ---------- Backward boundary expansion ----------
@pytest.mark.asyncio
async def test_boundary_expansion_includes_assistant(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When page starts with a tool message, expand backward to include
the owning assistant message."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(5, role="tool"), _make_msg(4, role="tool")],
)
find_many.return_value = [_make_msg(3, role="assistant")]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert [m.sequence for m in page.messages] == [3, 4, 5]
assert page.messages[0].role == "assistant"
assert page.oldest_sequence == 3
@pytest.mark.asyncio
async def test_boundary_expansion_includes_multiple_tool_msgs(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Boundary expansion scans past consecutive tool messages to find
the owning assistant."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(7, role="tool")],
)
find_many.return_value = [
_make_msg(6, role="tool"),
_make_msg(5, role="tool"),
_make_msg(4, role="assistant"),
]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert [m.sequence for m in page.messages] == [4, 5, 6, 7]
assert page.messages[0].role == "assistant"
@pytest.mark.asyncio
async def test_boundary_expansion_sets_has_more_when_not_at_start(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""After boundary expansion, has_more=True if expanded msgs aren't at seq 0."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3, role="tool")],
)
find_many.return_value = [_make_msg(2, role="assistant")]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert page.has_more is True
@pytest.mark.asyncio
async def test_boundary_expansion_no_has_more_at_conversation_start(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""has_more stays False when boundary expansion reaches seq 0."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(1, role="tool")],
)
find_many.return_value = [_make_msg(0, role="assistant")]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert page.has_more is False
assert page.oldest_sequence == 0
@pytest.mark.asyncio
async def test_no_boundary_expansion_when_first_msg_not_tool(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""No boundary expansion when the first message is not a tool message."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3, role="user"), _make_msg(2, role="assistant")],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert find_many.call_count == 0
assert [m.sequence for m in page.messages] == [2, 3]
@pytest.mark.asyncio
async def test_boundary_expansion_warns_when_no_owner_found(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When boundary scan doesn't find a non-tool message, a warning is logged
and the boundary messages are still included."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(10, role="tool")],
)
find_many.return_value = [_make_msg(i, role="tool") for i in range(9, -1, -1)]
with patch("backend.copilot.db.logger") as mock_logger:
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
mock_logger.warning.assert_called_once()
assert page is not None
assert page.messages[0].role == "tool"
assert len(page.messages) > 1
# ---------- Turn duration (integration tests) ----------
@pytest.mark.asyncio(loop_scope="session")
@@ -15,8 +349,8 @@ async def test_set_turn_duration_updates_cache_in_place(setup_test_user, test_us
"""
session = ChatSession.new(user_id=test_user_id, dry_run=False)
session.messages = [
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi there"),
CopilotChatMessage(role="user", content="hello"),
CopilotChatMessage(role="assistant", content="hi there"),
]
session = await upsert_chat_session(session)
@@ -41,7 +375,7 @@ async def test_set_turn_duration_no_assistant_message(setup_test_user, test_user
"""set_turn_duration is a no-op when there are no assistant messages."""
session = ChatSession.new(user_id=test_user_id, dry_run=False)
session.messages = [
ChatMessage(role="user", content="hello"),
CopilotChatMessage(role="user", content="hello"),
]
session = await upsert_chat_session(session)

View File

@@ -151,8 +151,8 @@ class CoPilotProcessor:
This method is called once per worker thread to set up the async event
loop and initialize any required resources.
Database is accessed only through DatabaseManager, so we don't need to connect
to Prisma directly.
DB operations route through DatabaseManagerAsyncClient (RPC) via the
db_accessors pattern — no direct Prisma connection is needed here.
"""
configure_logging()
set_service_name("CoPilotExecutor")

View File

@@ -64,6 +64,7 @@ class ChatMessage(BaseModel):
refusal: str | None = None
tool_calls: list[dict] | None = None
function_call: dict | None = None
sequence: int | None = None
duration_ms: int | None = None
@staticmethod
@@ -77,6 +78,7 @@ class ChatMessage(BaseModel):
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
sequence=prisma_message.sequence,
duration_ms=prisma_message.durationMs,
)

View File

@@ -15,6 +15,7 @@ from prisma.models import User as PrismaUser
from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from backend.data.db_accessors import user_db
from backend.data.redis_client import get_redis_async
from backend.util.cache import cached
@@ -409,9 +410,12 @@ async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
cached and then persists after the user is created with a higher tier.
"""
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
if user and user.subscriptionTier: # type: ignore[reportAttributeAccessIssue]
return SubscriptionTier(user.subscriptionTier) # type: ignore[reportAttributeAccessIssue]
try:
user = await user_db().get_user_by_id(user_id)
except Exception:
raise _UserNotFoundError(user_id)
if user.subscription_tier:
return SubscriptionTier(user.subscription_tier)
raise _UserNotFoundError(user_id)

View File

@@ -401,66 +401,49 @@ class TestGetUserTier:
"""Clear the get_user_tier cache before each test."""
get_user_tier.cache_clear() # type: ignore[attr-defined]
def _mock_user_db(
self, subscription_tier: str | None = None, raises: Exception | None = None
):
"""Return a patched user_db() whose get_user_by_id behaves as specified."""
mock_db = AsyncMock()
if raises is not None:
mock_db.get_user_by_id = AsyncMock(side_effect=raises)
else:
mock_user = MagicMock()
mock_user.subscription_tier = subscription_tier
mock_db.get_user_by_id = AsyncMock(return_value=mock_user)
return mock_db
@pytest.mark.asyncio
async def test_returns_tier_from_db(self):
"""Should return the tier stored in the user record."""
mock_user = MagicMock()
mock_user.subscriptionTier = "PRO"
mock_prisma = AsyncMock()
mock_prisma.find_unique = AsyncMock(return_value=mock_user)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
mock_db = self._mock_user_db(subscription_tier="PRO")
with patch("backend.copilot.rate_limit.user_db", return_value=mock_db):
tier = await get_user_tier(_USER)
assert tier == SubscriptionTier.PRO
@pytest.mark.asyncio
async def test_returns_default_when_user_not_found(self):
"""Should return DEFAULT_TIER when user is not in the DB."""
mock_prisma = AsyncMock()
mock_prisma.find_unique = AsyncMock(return_value=None)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
mock_db = self._mock_user_db(raises=Exception("not found"))
with patch("backend.copilot.rate_limit.user_db", return_value=mock_db):
tier = await get_user_tier(_USER)
assert tier == DEFAULT_TIER
@pytest.mark.asyncio
async def test_returns_default_when_tier_is_none(self):
"""Should return DEFAULT_TIER when subscriptionTier is None."""
mock_user = MagicMock()
mock_user.subscriptionTier = None
mock_prisma = AsyncMock()
mock_prisma.find_unique = AsyncMock(return_value=mock_user)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
"""Should return DEFAULT_TIER when subscription_tier is None."""
mock_db = self._mock_user_db(subscription_tier=None)
with patch("backend.copilot.rate_limit.user_db", return_value=mock_db):
tier = await get_user_tier(_USER)
assert tier == DEFAULT_TIER
@pytest.mark.asyncio
async def test_returns_default_on_db_error(self):
"""Should fall back to DEFAULT_TIER when DB raises."""
mock_prisma = AsyncMock()
mock_prisma.find_unique = AsyncMock(side_effect=Exception("DB down"))
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
mock_db = self._mock_user_db(raises=Exception("DB down"))
with patch("backend.copilot.rate_limit.user_db", return_value=mock_db):
tier = await get_user_tier(_USER)
assert tier == DEFAULT_TIER
@pytest.mark.asyncio
@@ -470,26 +453,14 @@ class TestGetUserTier:
Regression test: a transient DB failure previously cached DEFAULT_TIER
for 5 minutes, incorrectly downgrading higher-tier users until expiry.
"""
failing_prisma = AsyncMock()
failing_prisma.find_unique = AsyncMock(side_effect=Exception("DB down"))
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=failing_prisma,
):
failing_db = self._mock_user_db(raises=Exception("DB down"))
with patch("backend.copilot.rate_limit.user_db", return_value=failing_db):
tier1 = await get_user_tier(_USER)
assert tier1 == DEFAULT_TIER
# Now DB recovers and returns PRO
mock_user = MagicMock()
mock_user.subscriptionTier = "PRO"
ok_prisma = AsyncMock()
ok_prisma.find_unique = AsyncMock(return_value=mock_user)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=ok_prisma,
):
ok_db = self._mock_user_db(subscription_tier="PRO")
with patch("backend.copilot.rate_limit.user_db", return_value=ok_db):
tier2 = await get_user_tier(_USER)
# Should get PRO now — the error result was not cached
@@ -498,18 +469,9 @@ class TestGetUserTier:
@pytest.mark.asyncio
async def test_returns_default_on_invalid_tier_value(self):
"""Should fall back to DEFAULT_TIER when stored value is invalid."""
mock_user = MagicMock()
mock_user.subscriptionTier = "invalid-tier"
mock_prisma = AsyncMock()
mock_prisma.find_unique = AsyncMock(return_value=mock_user)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
mock_db = self._mock_user_db(subscription_tier="invalid-tier")
with patch("backend.copilot.rate_limit.user_db", return_value=mock_db):
tier = await get_user_tier(_USER)
assert tier == DEFAULT_TIER
@pytest.mark.asyncio
@@ -522,26 +484,14 @@ class TestGetUserTier:
stale cached FREE tier for up to 5 minutes.
"""
# First call: user does not exist yet
missing_prisma = AsyncMock()
missing_prisma.find_unique = AsyncMock(return_value=None)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=missing_prisma,
):
missing_db = self._mock_user_db(raises=Exception("not found"))
with patch("backend.copilot.rate_limit.user_db", return_value=missing_db):
tier1 = await get_user_tier(_USER)
assert tier1 == DEFAULT_TIER
# Second call: user now exists with PRO tier
mock_user = MagicMock()
mock_user.subscriptionTier = "PRO"
ok_prisma = AsyncMock()
ok_prisma.find_unique = AsyncMock(return_value=mock_user)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=ok_prisma,
):
ok_db = self._mock_user_db(subscription_tier="PRO")
with patch("backend.copilot.rate_limit.user_db", return_value=ok_db):
tier2 = await get_user_tier(_USER)
# Should get PRO — the not-found result was not cached
@@ -598,20 +548,19 @@ class TestSetUserTier:
@pytest.mark.asyncio
async def test_cache_invalidated_after_set(self):
"""After set_user_tier, get_user_tier should query DB again (not cache)."""
# First, populate the cache with BUSINESS
# First, populate the cache with BUSINESS via user_db() mock
mock_db_biz = AsyncMock()
mock_user_biz = MagicMock()
mock_user_biz.subscriptionTier = "BUSINESS"
mock_prisma_get = AsyncMock()
mock_prisma_get.find_unique = AsyncMock(return_value=mock_user_biz)
mock_user_biz.subscription_tier = "BUSINESS"
mock_db_biz.get_user_by_id = AsyncMock(return_value=mock_user_biz)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma_get,
):
with patch("backend.copilot.rate_limit.user_db", return_value=mock_db_biz):
tier_before = await get_user_tier(_USER)
assert tier_before == SubscriptionTier.BUSINESS
# Now set tier to ENTERPRISE (this should invalidate the cache)
# Now set tier to ENTERPRISE via PrismaUser.prisma (set_user_tier still
# uses Prisma directly since it's only called from admin API where Prisma
# is connected).
mock_prisma_set = AsyncMock()
mock_prisma_set.update = AsyncMock(return_value=None)
@@ -622,15 +571,12 @@ class TestSetUserTier:
await set_user_tier(_USER, SubscriptionTier.ENTERPRISE)
# Now get_user_tier should hit DB again (cache was invalidated)
mock_db_ent = AsyncMock()
mock_user_ent = MagicMock()
mock_user_ent.subscriptionTier = "ENTERPRISE"
mock_prisma_get2 = AsyncMock()
mock_prisma_get2.find_unique = AsyncMock(return_value=mock_user_ent)
mock_user_ent.subscription_tier = "ENTERPRISE"
mock_db_ent.get_user_by_id = AsyncMock(return_value=mock_user_ent)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma_get2,
):
with patch("backend.copilot.rate_limit.user_db", return_value=mock_db_ent):
tier_after = await get_user_tier(_USER)
assert tier_after == SubscriptionTier.ENTERPRISE

View File

@@ -29,6 +29,7 @@ from claude_agent_sdk import (
)
from langfuse import propagate_attributes
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
from opentelemetry import trace as otel_trace
from pydantic import BaseModel
from backend.copilot.context import get_workspace_manager
@@ -64,7 +65,6 @@ from ..model import (
ChatSession,
get_chat_session,
maybe_append_user_message,
update_session_title,
upsert_chat_session,
)
from ..prompting import get_sdk_supplement
@@ -83,11 +83,7 @@ from ..response_model import (
StreamToolOutputAvailable,
StreamUsage,
)
from ..service import (
_build_system_prompt,
_generate_session_title,
_is_langfuse_configured,
)
from ..service import _build_system_prompt, _is_langfuse_configured, _update_title_async
from ..token_tracking import persist_and_record_usage
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
@@ -2372,8 +2368,26 @@ async def stream_chat_completion_sdk(
raise
finally:
# --- Close OTEL context ---
# --- Close OTEL context (with cost attributes) ---
if _otel_ctx is not None:
try:
span = otel_trace.get_current_span()
if span and span.is_recording():
span.set_attribute("gen_ai.usage.prompt_tokens", turn_prompt_tokens)
span.set_attribute(
"gen_ai.usage.completion_tokens", turn_completion_tokens
)
span.set_attribute(
"gen_ai.usage.cache_read_tokens", turn_cache_read_tokens
)
span.set_attribute(
"gen_ai.usage.cache_creation_tokens",
turn_cache_creation_tokens,
)
if turn_cost_usd is not None:
span.set_attribute("gen_ai.usage.cost_usd", turn_cost_usd)
except Exception:
logger.debug("Failed to set OTEL cost attributes", exc_info=True)
try:
_otel_ctx.__exit__(*sys.exc_info())
except Exception:
@@ -2391,6 +2405,8 @@ async def stream_chat_completion_sdk(
cache_creation_tokens=turn_cache_creation_tokens,
log_prefix=log_prefix,
cost_usd=turn_cost_usd,
model=config.model,
provider="anthropic",
)
# --- Persist session messages ---
@@ -2495,18 +2511,3 @@ async def stream_chat_completion_sdk(
finally:
# Release stream lock to allow new streams for this session
await lock.release()
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
"""Background task to update session title."""
try:
title = await _generate_session_title(
message, user_id=user_id, session_id=session_id
)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
logger.debug("[SDK] Generated title for %s: %s", session_id, title)
except Exception as e:
logger.warning("[SDK] Failed to update session title: %s", e)

View File

@@ -22,7 +22,12 @@ from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.settings import AppEnvironment, Settings
from .config import ChatConfig
from .model import ChatSessionInfo, get_chat_session, upsert_chat_session
from .model import (
ChatSessionInfo,
get_chat_session,
update_session_title,
upsert_chat_session,
)
logger = logging.getLogger(__name__)
@@ -202,6 +207,22 @@ async def _generate_session_title(
return None
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
"""Generate and persist a session title in the background.
Shared by both the SDK and baseline execution paths.
"""
try:
title = await _generate_session_title(message, user_id, session_id)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
logger.debug("Generated title for session %s", session_id)
except Exception as e:
logger.warning("Failed to update session title for %s: %s", session_id, e)
async def assign_user_to_session(
session_id: str,
user_id: str,

View File

@@ -4,17 +4,85 @@ Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
1. Append a ``Usage`` record to the session.
2. Log the turn's token counts.
3. Record weighted usage in Redis for rate-limiting.
4. Write a PlatformCostLog entry for admin cost tracking.
This module extracts that common logic so both paths stay in sync.
"""
import asyncio
import logging
import math
import re
import threading
from backend.data.db_accessors import platform_cost_db
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
from .model import ChatSession, Usage
from .rate_limit import record_token_usage
logger = logging.getLogger(__name__)
# Hold strong references to in-flight cost log tasks to prevent GC.
_pending_log_tasks: set[asyncio.Task[None]] = set()
# Guards all reads and writes to _pending_log_tasks. Done callbacks (discard)
# fire from the event loop thread; drain_pending_cost_logs iterates the set
# from any caller — the lock prevents RuntimeError from concurrent modification.
_pending_log_tasks_lock = threading.Lock()
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
# shared across event loops running in different threads.
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}
def _get_log_semaphore() -> asyncio.Semaphore:
loop = asyncio.get_running_loop()
sem = _log_semaphores.get(loop)
if sem is None:
sem = asyncio.Semaphore(50)
_log_semaphores[loop] = sem
return sem
def _schedule_cost_log(entry: PlatformCostEntry) -> None:
"""Schedule a fire-and-forget cost log via DatabaseManagerAsyncClient RPC."""
async def _safe_log() -> None:
async with _get_log_semaphore():
try:
await platform_cost_db().log_platform_cost(entry)
except Exception:
logger.exception(
"Failed to log platform cost for user=%s provider=%s block=%s",
entry.user_id,
entry.provider,
entry.block_name,
)
task = asyncio.create_task(_safe_log())
with _pending_log_tasks_lock:
_pending_log_tasks.add(task)
def _remove(t: asyncio.Task[None]) -> None:
with _pending_log_tasks_lock:
_pending_log_tasks.discard(t)
task.add_done_callback(_remove)
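# Illustrative call (field values hypothetical): fire-and-forget from a hot
# path. The strong reference held in _pending_log_tasks keeps the task alive
# until its done-callback discards it, and the per-loop semaphore bounds
# concurrent RPC writes to 50:
#
#     _schedule_cost_log(PlatformCostEntry(user_id=uid, provider="anthropic", ...))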
# Identifiers used by PlatformCostLog for copilot turns (not tied to a real
# block/credential in the block_cost_config or credentials_store tables).
COPILOT_BLOCK_ID = "copilot"
COPILOT_CREDENTIAL_ID = "copilot_system"
def _copilot_block_name(log_prefix: str) -> str:
"""Extract stable block_name from ``"[SDK][session][T1]"`` -> ``"copilot:SDK"``."""
match = re.search(r"\[([A-Za-z][A-Za-z0-9_]*)\]", log_prefix)
if match:
return f"{COPILOT_BLOCK_ID}:{match.group(1)}"
tag = log_prefix.strip(" []")
return f"{COPILOT_BLOCK_ID}:{tag}" if tag else COPILOT_BLOCK_ID
async def persist_and_record_usage(
*,
@@ -26,6 +94,8 @@ async def persist_and_record_usage(
cache_creation_tokens: int = 0,
log_prefix: str = "",
cost_usd: float | str | None = None,
model: str | None = None,
provider: str = "open_router",
) -> int:
"""Persist token usage to session and record for rate limiting.
@@ -38,6 +108,7 @@ async def persist_and_record_usage(
cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
cost_usd: Optional cost for logging (float from SDK, str otherwise).
provider: Cost provider name (e.g. "anthropic", "open_router").
Returns:
The computed total_tokens (prompt + completion; cache excluded).
@@ -47,12 +118,13 @@ async def persist_and_record_usage(
cache_read_tokens = max(0, cache_read_tokens)
cache_creation_tokens = max(0, cache_creation_tokens)
-    if (
+    no_tokens = (
         prompt_tokens <= 0
         and completion_tokens <= 0
         and cache_read_tokens <= 0
         and cache_creation_tokens <= 0
-    ):
+    )
+    if no_tokens and cost_usd is None:
         return 0
# total_tokens = prompt + completion. Cache tokens are tracked
@@ -73,14 +145,14 @@ async def persist_and_record_usage(
if cache_read_tokens or cache_creation_tokens:
logger.info(
f"{log_prefix} Turn usage: uncached={prompt_tokens}, "
f"cache_read={cache_read_tokens}, cache_create={cache_creation_tokens}, "
f"output={completion_tokens}, total={total_tokens}, cost_usd={cost_usd}"
f"{log_prefix} Turn usage: uncached={prompt_tokens}, cache_read={cache_read_tokens},"
f" cache_create={cache_creation_tokens}, output={completion_tokens},"
f" total={total_tokens}, cost_usd={cost_usd}"
)
else:
logger.info(
f"{log_prefix} Turn usage: prompt={prompt_tokens}, "
f"completion={completion_tokens}, total={total_tokens}"
f"{log_prefix} Turn usage: prompt={prompt_tokens}, completion={completion_tokens},"
f" total={total_tokens}"
)
if user_id:
@@ -93,6 +165,54 @@ async def persist_and_record_usage(
cache_creation_tokens=cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")
logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)
# Log to PlatformCostLog for admin cost dashboard.
# Include entries where cost_usd is set even if token count is 0
# (e.g. fully-cached Anthropic responses where only cache tokens
# accumulate a charge without incrementing total_tokens).
if user_id and (total_tokens > 0 or cost_usd is not None):
cost_float = None
if cost_usd is not None:
try:
val = float(cost_usd)
if math.isfinite(val) and val >= 0:
cost_float = val
except (ValueError, TypeError):
pass
cost_microdollars = usd_to_microdollars(cost_float)
session_id = session.session_id if session else None
if cost_float is not None:
tracking_type = "cost_usd"
tracking_amount = cost_float
else:
tracking_type = "tokens"
tracking_amount = total_tokens
_schedule_cost_log(
PlatformCostEntry(
user_id=user_id,
graph_exec_id=session_id,
block_id=COPILOT_BLOCK_ID,
block_name=_copilot_block_name(log_prefix),
provider=provider,
credential_id=COPILOT_CREDENTIAL_ID,
cost_microdollars=cost_microdollars,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
model=model,
tracking_type=tracking_type,
tracking_amount=tracking_amount,
metadata={
"tracking_type": tracking_type,
"tracking_amount": tracking_amount,
"cache_read_tokens": cache_read_tokens,
"cache_creation_tokens": cache_creation_tokens,
"source": "copilot",
},
)
)
return total_tokens
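The cost_usd guard above accepts floats and string-encoded numbers but rejects negatives and non-finite values, in which case tracking falls back to tokens. The guard in isolation (hypothetical helper name; logic copied from the function above):

```python
import math

def parse_cost(cost_usd: float | str | None) -> float | None:
    """Illustrative copy of the cost_usd validation in persist_and_record_usage."""
    if cost_usd is None:
        return None
    try:
        val = float(cost_usd)
    except (ValueError, TypeError):
        return None
    return val if math.isfinite(val) and val >= 0 else None

assert parse_cost(0.005) == 0.005
assert parse_cost("0.01") == 0.01          # OpenRouter reports cost as a string
assert parse_cost("not-a-number") is None  # falls back to token-based tracking
assert parse_cost(-0.01) is None           # negative cost rejected
assert parse_cost(float("nan")) is None    # non-finite rejected
```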

View File

@@ -4,6 +4,7 @@ Covers both the baseline (prompt+completion only) and SDK (with cache breakdown)
calling conventions, session persistence, and rate-limit recording.
"""
import asyncio
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
@@ -279,3 +280,290 @@ class TestRateLimitRecording:
completion_tokens=0,
)
mock_record.assert_not_awaited()
# ---------------------------------------------------------------------------
# PlatformCostLog integration
# ---------------------------------------------------------------------------
class TestPlatformCostLogging:
@pytest.mark.asyncio
async def test_logs_cost_entry_with_cost_usd(self):
"""When cost_usd is provided, tracking_type should be 'cost_usd'."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=_make_session(),
user_id="user-cost",
prompt_tokens=200,
completion_tokens=100,
cost_usd=0.005,
model="gpt-4",
provider="anthropic",
log_prefix="[SDK]",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.user_id == "user-cost"
assert entry.provider == "anthropic"
assert entry.model == "gpt-4"
assert entry.cost_microdollars == 5000
assert entry.input_tokens == 200
assert entry.output_tokens == 100
assert entry.tracking_type == "cost_usd"
assert entry.metadata["tracking_type"] == "cost_usd"
assert entry.metadata["tracking_amount"] == 0.005
assert entry.block_name == "copilot:SDK"
assert entry.graph_exec_id == "sess-test"
@pytest.mark.asyncio
async def test_logs_cost_entry_without_cost_usd(self):
"""When cost_usd is None, tracking_type should be 'tokens'."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-tokens",
prompt_tokens=100,
completion_tokens=50,
log_prefix="[Baseline]",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.cost_microdollars is None
assert entry.tracking_type == "tokens"
assert entry.metadata["tracking_type"] == "tokens"
assert entry.metadata["tracking_amount"] == 150
assert entry.graph_exec_id is None
assert entry.block_name == "copilot:Baseline"
@pytest.mark.asyncio
async def test_skips_cost_log_when_no_user_id(self):
"""No PlatformCostLog entry when user_id is None."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=100,
completion_tokens=50,
)
await asyncio.sleep(0)
mock_log.assert_not_awaited()
@pytest.mark.asyncio
async def test_cost_usd_invalid_string_falls_back_to_tokens(self):
"""Invalid cost_usd string should fall back to tokens tracking."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-invalid",
prompt_tokens=100,
completion_tokens=50,
cost_usd="not-a-number",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.cost_microdollars is None
assert entry.metadata["tracking_type"] == "tokens"
@pytest.mark.asyncio
async def test_cost_usd_string_number_is_parsed(self):
"""String-encoded cost_usd (e.g. from OpenRouter) should be parsed."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-str",
prompt_tokens=100,
completion_tokens=50,
cost_usd="0.01",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.cost_microdollars == 10_000
assert entry.metadata["tracking_type"] == "cost_usd"
@pytest.mark.asyncio
async def test_empty_log_prefix_produces_copilot_block_name(self):
"""Empty log_prefix results in block_name='copilot'."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-empty",
prompt_tokens=10,
completion_tokens=5,
log_prefix="",
)
await asyncio.sleep(0)
entry = mock_log.call_args[0][0]
assert entry.block_name == "copilot"
@pytest.mark.asyncio
async def test_cache_tokens_included_in_metadata(self):
"""Cache token counts should be present in the metadata."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-cache",
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=5000,
cache_creation_tokens=300,
)
await asyncio.sleep(0)
entry = mock_log.call_args[0][0]
assert entry.metadata["cache_read_tokens"] == 5000
assert entry.metadata["cache_creation_tokens"] == 300
assert entry.metadata["source"] == "copilot"
@pytest.mark.asyncio
async def test_logs_cost_only_when_tokens_zero(self):
"""Zero prompt+completion tokens with cost_usd set still logs the entry."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-cached",
prompt_tokens=0,
completion_tokens=0,
cost_usd=0.005,
model="claude-3-5-sonnet",
provider="anthropic",
log_prefix="[SDK]",
)
await asyncio.sleep(0)
# Guard: total_tokens == 0 but cost_usd is set — must still log
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.user_id == "user-cached"
assert entry.tracking_type == "cost_usd"
assert entry.cost_microdollars == 5000
assert entry.input_tokens == 0
assert entry.output_tokens == 0
@pytest.mark.asyncio
async def test_negative_cost_usd_falls_back_to_tokens(self):
"""Negative cost_usd must be rejected — val >= 0 guard in persist_and_record_usage."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-negative",
prompt_tokens=100,
completion_tokens=50,
cost_usd=-0.01,
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
# Negative cost rejected — falls back to token-based tracking
assert entry.cost_microdollars is None
assert entry.metadata["tracking_type"] == "tokens"

View File

@@ -845,6 +845,7 @@ class WriteWorkspaceFileTool(BaseTool):
path=path,
mime_type=mime_type,
overwrite=overwrite,
metadata={"origin": "agent-created"},
)
# Build informative source label and message.

View File

@@ -142,3 +142,16 @@ def credit_db():
credit_db = get_database_manager_async_client()
return credit_db
def platform_cost_db():
if db.is_connected():
from backend.data import platform_cost as _platform_cost_db
platform_cost_db = _platform_cost_db
else:
from backend.util.clients import get_database_manager_async_client
platform_cost_db = get_database_manager_async_client()
return platform_cost_db
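Both branches expose the same coroutine surface, so call sites stay agnostic about whether they run inside the DB-connected process or go over RPC. A minimal usage sketch (entry construction omitted):

```python
from backend.data.db_accessors import platform_cost_db
from backend.data.platform_cost import PlatformCostEntry

async def record_cost(entry: PlatformCostEntry) -> None:
    # Resolves to the platform_cost module in the DB service process,
    # and to the DatabaseManagerAsyncClient RPC proxy everywhere else.
    await platform_cost_db().log_platform_cost(entry)
```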

View File

@@ -96,6 +96,7 @@ from backend.data.notifications import (
remove_notifications_from_batch,
)
from backend.data.onboarding import increment_onboarding_runs
from backend.data.platform_cost import log_platform_cost
from backend.data.understanding import (
get_business_understanding,
upsert_business_understanding,
@@ -332,6 +333,9 @@ class DatabaseManager(AppService):
get_blocks_needing_optimization = _(get_blocks_needing_optimization)
update_block_optimized_description = _(update_block_optimized_description)
# ============ Platform Cost Tracking ============ #
log_platform_cost = _(log_platform_cost)
# ============ CoPilot Chat Sessions ============ #
get_chat_session = _(chat_db.get_chat_session)
create_chat_session = _(chat_db.create_chat_session)
@@ -529,6 +533,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# ============ Block Descriptions ============ #
get_blocks_needing_optimization = d.get_blocks_needing_optimization
# ============ Platform Cost Tracking ============ #
log_platform_cost = d.log_platform_cost
# ============ CoPilot Chat Sessions ============ #
get_chat_session = d.get_chat_session
create_chat_session = d.create_chat_session

View File

@@ -333,26 +333,29 @@ class BaseGraph(GraphBaseMeta):
except Exception as e:
logger.error(f"Invalid {type_class}: {input_default}, {e}")
-        return {
-            "type": "object",
-            "properties": {
-                p.name: {
-                    **{
-                        k: v
-                        for k, v in p.generate_schema().items()
-                        if k not in ["description", "default"]
-                    },
-                    "secret": p.secret,
-                    # Default value has to be set for advanced fields.
-                    "advanced": p.advanced and p.value is not None,
-                    "title": p.title or p.name,
-                    **({"description": p.description} if p.description else {}),
-                    **({"default": p.value} if p.value is not None else {}),
-                }
-                for p in schema_fields
-            },
-            "required": [p.name for p in schema_fields if p.value is None],
-        }
+        try:
+            return {
+                "type": "object",
+                "properties": {
+                    p.name: {
+                        **{
+                            k: v
+                            for k, v in p.generate_schema().items()
+                            if k not in ["description", "default"]
+                        },
+                        "secret": p.secret,
+                        # Default value has to be set for advanced fields.
+                        "advanced": p.advanced and p.value is not None,
+                        "title": p.title or p.name,
+                        **({"description": p.description} if p.description else {}),
+                        **({"default": p.value} if p.value is not None else {}),
+                    }
+                    for p in schema_fields
+                },
+                "required": [p.name for p in schema_fields if p.value is None],
+            }
+        except AttributeError as e:
+            raise ValueError(str(e)) from e
class GraphTriggerInfo(BaseModel):

View File

@@ -15,6 +15,7 @@ from backend.blocks.basic import StoreValueBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.data.graph import (
Graph,
GraphModel,
Link,
Node,
get_graph,
@@ -1460,3 +1461,21 @@ async def test_validate_graph_execution_permissions_library_wrong_version_denied
mock_is_published.assert_awaited_once_with(graph_id, graph_version)
lib_where = mock_lib_prisma.return_value.find_first.call_args.kwargs["where"]
assert lib_where["agentGraphVersion"] == graph_version
# ============================================================================
# Tests for _generate_schema AttributeError → ValueError conversion
# ============================================================================
def test_generate_schema_raises_value_error_when_name_missing():
"""AgentInputBlock.Input constructed without 'name' should raise ValueError.
model_construct() skips validation, so the Input object is created without
a 'name' attribute. The dict comprehension in _generate_schema then hits an
AttributeError when it accesses p.name. That AttributeError must be caught
and re-raised as ValueError so the existing 400 handler in rest_api.py fires
instead of falling through to the 500 catch-all.
"""
with pytest.raises(ValueError):
GraphModel._generate_schema((AgentInputBlock.Input, {}))

View File

@@ -21,7 +21,7 @@ from typing import (
)
from uuid import uuid4
-from prisma.enums import CreditTransactionType, OnboardingStep
+from prisma.enums import CreditTransactionType, OnboardingStep, SubscriptionTier
from pydantic import (
BaseModel,
ConfigDict,
@@ -54,7 +54,6 @@ class User(BaseModel):
"""Application-layer User model with snake_case convention."""
     model_config = ConfigDict(
         extra="forbid",
-        str_strip_whitespace=True,
     )
@@ -104,6 +103,9 @@ class User(BaseModel):
description="User timezone (IANA timezone identifier or 'not-set')",
)
# Subscription / rate-limit tier
subscription_tier: SubscriptionTier | None = Field(default=None)
@classmethod
def from_db(cls, prisma_user: "PrismaUser") -> "User":
"""Convert a database User object to application User model."""
@@ -158,6 +160,7 @@ class User(BaseModel):
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
timezone=prisma_user.timezone or USER_TIMEZONE_NOT_SET,
subscription_tier=prisma_user.subscriptionTier,
)
@@ -819,6 +822,17 @@ class RefundRequest(BaseModel):
updated_at: datetime
ProviderCostType = Literal[
"cost_usd", # Actual USD cost reported by the provider
"tokens", # LLM token counts (sum of input + output)
"characters", # Per-character billing (TTS providers)
"sandbox_seconds", # Per-second compute billing (e.g. E2B)
"walltime_seconds", # Per-second billing incl. queue/polling
"per_run", # Per-API-call billing with fixed cost
"items", # Per-item billing (lead/organization/result count)
]
class NodeExecutionStats(BaseModel):
"""Execution statistics for a node execution."""
@@ -838,32 +852,39 @@ class NodeExecutionStats(BaseModel):
output_token_count: int = 0
extra_cost: int = 0
extra_steps: int = 0
provider_cost: float | None = None
# Type of the provider-reported cost/usage captured above. When set
# by a block, resolve_tracking honors this directly instead of
# guessing from provider name.
provider_cost_type: Optional[ProviderCostType] = None
# Moderation fields
cleared_inputs: Optional[dict[str, list[str]]] = None
cleared_outputs: Optional[dict[str, list[str]]] = None
def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
"""Mutate this instance by adding another NodeExecutionStats."""
"""Mutate this instance by adding another NodeExecutionStats.
Avoids calling model_dump() twice per merge (called on every
merge_stats() from ~20+ blocks); reads via getattr/vars instead.
"""
if not isinstance(other, NodeExecutionStats):
return NotImplemented
stats_dict = other.model_dump()
current_stats = self.model_dump()
for key, value in stats_dict.items():
if key not in current_stats:
# Field doesn't exist yet, just set it
for key in type(other).model_fields:
value = getattr(other, key)
if value is None:
# Never overwrite an existing value with None
continue
current = getattr(self, key, None)
if current is None:
# Field doesn't exist yet or is None, just set it
setattr(self, key, value)
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
current_stats[key].update(value)
setattr(self, key, current_stats[key])
elif isinstance(value, (int, float)) and isinstance(
current_stats[key], (int, float)
):
setattr(self, key, current_stats[key] + value)
elif isinstance(value, list) and isinstance(current_stats[key], list):
current_stats[key].extend(value)
setattr(self, key, current_stats[key])
elif isinstance(value, dict) and isinstance(current, dict):
current.update(value)
elif isinstance(value, (int, float)) and isinstance(current, (int, float)):
setattr(self, key, current + value)
elif isinstance(value, list) and isinstance(current, list):
current.extend(value)
else:
setattr(self, key, value)
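Beyond dropping the double model_dump(), the behavioral change is that None values from `other` no longer clobber existing values. Expected merge semantics (illustrative; matches the new tests below):

```python
from backend.data.model import NodeExecutionStats

a = NodeExecutionStats(input_token_count=100, provider_cost=0.01,
                       cleared_inputs={"f1": ["v1"]})
b = NodeExecutionStats(input_token_count=200, provider_cost=None,
                       cleared_inputs={"f2": ["v2"]})
a += b
assert a.input_token_count == 300                        # numerics accumulate
assert a.provider_cost == 0.01                           # None never overwrites
assert a.cleared_inputs == {"f1": ["v1"], "f2": ["v2"]}  # dicts merge in place
```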

View File

@@ -1,7 +1,7 @@
import pytest
from pydantic import SecretStr
-from backend.data.model import HostScopedCredentials
+from backend.data.model import HostScopedCredentials, NodeExecutionStats
class TestHostScopedCredentials:
@@ -166,3 +166,84 @@ class TestHostScopedCredentials:
)
assert creds.matches_url(test_url) == expected
class TestNodeExecutionStatsIadd:
def test_adds_numeric_fields(self):
a = NodeExecutionStats(input_token_count=100, output_token_count=50)
b = NodeExecutionStats(input_token_count=200, output_token_count=30)
a += b
assert a.input_token_count == 300
assert a.output_token_count == 80
def test_none_does_not_overwrite(self):
a = NodeExecutionStats(provider_cost=0.5, error="some error")
b = NodeExecutionStats(provider_cost=None, error=None)
a += b
assert a.provider_cost == 0.5
assert a.error == "some error"
def test_none_is_skipped_preserving_existing_value(self):
a = NodeExecutionStats(input_token_count=100)
b = NodeExecutionStats()
a += b
assert a.input_token_count == 100
def test_dict_fields_are_merged(self):
a = NodeExecutionStats(
cleared_inputs={"field1": ["val1"]},
)
b = NodeExecutionStats(
cleared_inputs={"field2": ["val2"]},
)
a += b
assert a.cleared_inputs == {"field1": ["val1"], "field2": ["val2"]}
def test_returns_self(self):
a = NodeExecutionStats()
b = NodeExecutionStats(input_token_count=10)
result = a.__iadd__(b)
assert result is a
def test_not_implemented_for_non_stats(self):
a = NodeExecutionStats()
result = a.__iadd__("not a stats") # type: ignore[arg-type]
assert result is NotImplemented
def test_error_none_does_not_clear_existing_error(self):
a = NodeExecutionStats(error="existing error")
b = NodeExecutionStats(error=None)
a += b
assert a.error == "existing error"
def test_provider_cost_none_does_not_clear_existing_cost(self):
a = NodeExecutionStats(provider_cost=0.05)
b = NodeExecutionStats(provider_cost=None)
a += b
assert a.provider_cost == 0.05
def test_provider_cost_accumulates_when_both_set(self):
a = NodeExecutionStats(provider_cost=0.01)
b = NodeExecutionStats(provider_cost=0.02)
a += b
assert abs((a.provider_cost or 0) - 0.03) < 1e-9
def test_provider_cost_first_write_from_none(self):
a = NodeExecutionStats()
b = NodeExecutionStats(provider_cost=0.05)
a += b
assert a.provider_cost == 0.05
def test_provider_cost_type_first_write_from_none(self):
"""Writing provider_cost_type into a stats with None sets it."""
a = NodeExecutionStats()
b = NodeExecutionStats(provider_cost_type="characters")
a += b
assert a.provider_cost_type == "characters"
def test_provider_cost_type_none_does_not_overwrite(self):
"""A None provider_cost_type from other must not clear an existing value."""
a = NodeExecutionStats(provider_cost_type="tokens")
b = NodeExecutionStats()
a += b
assert a.provider_cost_type == "tokens"

View File

@@ -0,0 +1,378 @@
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Any
from prisma.models import PlatformCostLog as PrismaLog
from prisma.types import PlatformCostLogCreateInput
from pydantic import BaseModel
from backend.data.db import query_raw_with_schema
from backend.util.cache import cached
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
MICRODOLLARS_PER_USD = 1_000_000
# Dashboard query limits — keep in sync with the SQL queries below
MAX_PROVIDER_ROWS = 500
MAX_USER_ROWS = 100
# Default date range for dashboard queries when no start date is provided.
# Prevents full-table scans on large deployments.
DEFAULT_DASHBOARD_DAYS = 30
def usd_to_microdollars(cost_usd: float | None) -> int | None:
"""Convert a USD amount (float) to microdollars (int). None-safe."""
if cost_usd is None:
return None
return round(cost_usd * MICRODOLLARS_PER_USD)
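Worked conversions (1 USD = 1,000,000 microdollars, rounded to the nearest integer):

```python
assert usd_to_microdollars(0.005) == 5_000
assert usd_to_microdollars(1.0) == 1_000_000
assert usd_to_microdollars(None) is None  # None-safe passthrough
```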
class PlatformCostEntry(BaseModel):
user_id: str
graph_exec_id: str | None = None
node_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
block_id: str | None = None
block_name: str | None = None
provider: str
credential_id: str | None = None
cost_microdollars: int | None = None
input_tokens: int | None = None
output_tokens: int | None = None
data_size: int | None = None
duration: float | None = None
model: str | None = None
tracking_type: str | None = None
tracking_amount: float | None = None
metadata: dict[str, Any] | None = None
async def log_platform_cost(entry: PlatformCostEntry) -> None:
await PrismaLog.prisma().create(
data=PlatformCostLogCreateInput(
userId=entry.user_id,
graphExecId=entry.graph_exec_id,
nodeExecId=entry.node_exec_id,
graphId=entry.graph_id,
nodeId=entry.node_id,
blockId=entry.block_id,
blockName=entry.block_name,
# Normalize to lowercase so the (provider, createdAt) index is always
# used without LOWER() on the read side.
provider=entry.provider.lower(),
credentialId=entry.credential_id,
costMicrodollars=entry.cost_microdollars,
inputTokens=entry.input_tokens,
outputTokens=entry.output_tokens,
dataSize=entry.data_size,
duration=entry.duration,
model=entry.model,
trackingType=entry.tracking_type,
trackingAmount=entry.tracking_amount,
metadata=SafeJson(entry.metadata or {}),
)
)
# Bound the number of concurrent cost-log DB inserts to prevent unbounded
# task/connection growth under sustained load or DB slowness.
_log_semaphore = asyncio.Semaphore(50)
async def log_platform_cost_safe(entry: PlatformCostEntry) -> None:
"""Fire-and-forget wrapper that never raises."""
try:
async with _log_semaphore:
await log_platform_cost(entry)
except Exception:
logger.exception(
"Failed to log platform cost for user=%s provider=%s block=%s",
entry.user_id,
entry.provider,
entry.block_name,
)
def _mask_email(email: str | None) -> str | None:
"""Mask an email address to reduce PII exposure in admin API responses.
Turns 'user@example.com' into 'us***@example.com'.
Handles short local parts gracefully (e.g. 'a@b.com' -> 'a***@b.com').
"""
if not email:
return email
at = email.find("@")
if at < 0:
return "***"
local = email[:at]
domain = email[at:]
visible = local[:2] if len(local) >= 2 else local[:1]
return f"{visible}***{domain}"
class ProviderCostSummary(BaseModel):
provider: str
tracking_type: str | None = None
total_cost_microdollars: int
total_input_tokens: int
total_output_tokens: int
total_duration_seconds: float = 0.0
total_tracking_amount: float = 0.0
request_count: int
class UserCostSummary(BaseModel):
user_id: str | None = None
email: str | None = None
total_cost_microdollars: int
total_input_tokens: int
total_output_tokens: int
request_count: int
class CostLogRow(BaseModel):
id: str
created_at: datetime
user_id: str | None = None
email: str | None = None
graph_exec_id: str | None = None
node_exec_id: str | None = None
block_name: str
provider: str
tracking_type: str | None = None
cost_microdollars: int | None = None
input_tokens: int | None = None
output_tokens: int | None = None
duration: float | None = None
model: str | None = None
class PlatformCostDashboard(BaseModel):
by_provider: list[ProviderCostSummary]
by_user: list[UserCostSummary]
total_cost_microdollars: int
total_requests: int
total_users: int
def _build_where(
start: datetime | None,
end: datetime | None,
provider: str | None,
user_id: str | None,
table_alias: str = "",
) -> tuple[str, list[Any]]:
prefix = f"{table_alias}." if table_alias else ""
clauses: list[str] = []
params: list[Any] = []
idx = 1
if start:
clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
params.append(start)
idx += 1
if end:
clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
params.append(end)
idx += 1
if provider:
# Provider names are normalized to lowercase at write time so a plain
# equality check is sufficient and the (provider, createdAt) index is used.
clauses.append(f'{prefix}"provider" = ${idx}')
params.append(provider.lower())
idx += 1
if user_id:
clauses.append(f'{prefix}"userId" = ${idx}')
params.append(user_id)
idx += 1
return (" AND ".join(clauses) if clauses else "TRUE", params)
@cached(ttl_seconds=30)
async def get_platform_cost_dashboard(
start: datetime | None = None,
end: datetime | None = None,
provider: str | None = None,
user_id: str | None = None,
) -> PlatformCostDashboard:
"""Aggregate platform cost logs for the admin dashboard.
Note: by_provider rows are keyed on (provider, tracking_type). A single
provider can therefore appear in multiple rows if it has entries with
different billing models (e.g. "openai" with both "tokens" and "cost_usd"
if pricing is later added for some entries). Frontend treats each row
independently rather than as a provider primary key.
Defaults to the last DEFAULT_DASHBOARD_DAYS days when no start date is
provided to avoid full-table scans on large deployments.
"""
if start is None:
start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_p, params_p = _build_where(start, end, provider, user_id, "p")
by_provider_rows, by_user_rows, total_user_rows = await asyncio.gather(
query_raw_with_schema(
f"""
SELECT
p."provider",
p."trackingType" AS tracking_type,
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
COALESCE(SUM(p."duration"), 0)::float AS total_duration,
COALESCE(SUM(p."trackingAmount"), 0)::float AS total_tracking_amount,
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
GROUP BY p."provider", p."trackingType"
ORDER BY total_cost DESC
LIMIT {MAX_PROVIDER_ROWS}
""",
*params_p,
),
query_raw_with_schema(
f"""
SELECT
p."userId" AS user_id,
u."email",
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
WHERE {where_p}
GROUP BY p."userId", u."email"
ORDER BY total_cost DESC
LIMIT {MAX_USER_ROWS}
""",
*params_p,
),
query_raw_with_schema(
f"""
SELECT COUNT(DISTINCT p."userId")::bigint AS cnt
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
""",
*params_p,
),
)
# Use the exact COUNT(DISTINCT userId) so total_users is not capped at
# MAX_USER_ROWS (which would silently report 100 for >100 active users).
total_users = int(total_user_rows[0]["cnt"]) if total_user_rows else 0
total_cost = sum(r["total_cost"] for r in by_provider_rows)
total_requests = sum(r["request_count"] for r in by_provider_rows)
return PlatformCostDashboard(
by_provider=[
ProviderCostSummary(
provider=r["provider"],
tracking_type=r.get("tracking_type"),
total_cost_microdollars=r["total_cost"],
total_input_tokens=r["total_input_tokens"],
total_output_tokens=r["total_output_tokens"],
total_duration_seconds=r.get("total_duration", 0.0),
total_tracking_amount=r.get("total_tracking_amount", 0.0),
request_count=r["request_count"],
)
for r in by_provider_rows
],
by_user=[
UserCostSummary(
user_id=r.get("user_id"),
email=_mask_email(r.get("email")),
total_cost_microdollars=r["total_cost"],
total_input_tokens=r["total_input_tokens"],
total_output_tokens=r["total_output_tokens"],
request_count=r["request_count"],
)
for r in by_user_rows
],
total_cost_microdollars=total_cost,
total_requests=total_requests,
total_users=total_users,
)
async def get_platform_cost_logs(
start: datetime | None = None,
end: datetime | None = None,
provider: str | None = None,
user_id: str | None = None,
page: int = 1,
page_size: int = 50,
) -> tuple[list[CostLogRow], int]:
if start is None:
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_sql, params = _build_where(start, end, provider, user_id, "p")
offset = (page - 1) * page_size
limit_idx = len(params) + 1
offset_idx = len(params) + 2
count_rows, rows = await asyncio.gather(
query_raw_with_schema(
f"""
SELECT COUNT(*)::bigint AS cnt
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_sql}
""",
*params,
),
query_raw_with_schema(
f"""
SELECT
p."id",
p."createdAt" AS created_at,
p."userId" AS user_id,
u."email",
p."graphExecId" AS graph_exec_id,
p."nodeExecId" AS node_exec_id,
p."blockName" AS block_name,
p."provider",
p."trackingType" AS tracking_type,
p."costMicrodollars" AS cost_microdollars,
p."inputTokens" AS input_tokens,
p."outputTokens" AS output_tokens,
p."duration",
p."model"
FROM {{schema_prefix}}"PlatformCostLog" p
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
WHERE {where_sql}
ORDER BY p."createdAt" DESC, p."id" DESC
LIMIT ${limit_idx} OFFSET ${offset_idx}
""",
*params,
page_size,
offset,
),
)
total = count_rows[0]["cnt"] if count_rows else 0
logs = [
CostLogRow(
id=r["id"],
created_at=r["created_at"],
user_id=r.get("user_id"),
email=_mask_email(r.get("email")),
graph_exec_id=r.get("graph_exec_id"),
node_exec_id=r.get("node_exec_id"),
block_name=r["block_name"],
provider=r["provider"],
tracking_type=r.get("tracking_type"),
cost_microdollars=r.get("cost_microdollars"),
input_tokens=r.get("input_tokens"),
output_tokens=r.get("output_tokens"),
duration=r.get("duration"),
model=r.get("model"),
)
for r in rows
]
return logs, total
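Pagination is plain LIMIT/OFFSET, with the bind indices appended after however many filter params came first. For page=3 with page_size=25 and only the default start-date filter (sketch):

```python
params_so_far = 1                # the default 30-day start date is always bound
page, page_size = 3, 25
offset = (page - 1) * page_size  # 50: skip the first two pages
limit_idx = params_so_far + 1    # LIMIT binds as $2
offset_idx = params_so_far + 2   # OFFSET binds as $3
assert (offset, limit_idx, offset_idx) == (50, 2, 3)
```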

View File

@@ -0,0 +1,79 @@
"""
Integration tests for platform cost logging.
These tests run actual database operations to verify that SafeJson metadata
round-trips correctly through Prisma — catching the DataError that occurred
when a plain Python dict was passed to the Prisma Json? field.
"""
import uuid
import pytest
from prisma.models import PlatformCostLog as PrismaLog
from prisma.models import User
from backend.util.json import SafeJson
from .platform_cost import PlatformCostEntry, log_platform_cost
@pytest.fixture
async def cost_log_user():
"""Create a throw-away user and clean up cost logs after the test."""
user_id = str(uuid.uuid4())
await User.prisma().create(
data={
"id": user_id,
"email": f"cost-test-{user_id}@example.com",
"topUpConfig": SafeJson({}),
"timezone": "UTC",
}
)
yield user_id
await PrismaLog.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete(where={"id": user_id})
@pytest.mark.asyncio(loop_scope="session")
async def test_log_platform_cost_metadata_round_trip(cost_log_user):
"""
Verify that SafeJson metadata is persisted and read back correctly.
This test would have caught the DataError that silently swallowed all cost
log writes when a plain Python dict was passed to the Prisma Json? field.
"""
user_id = cost_log_user
entry = PlatformCostEntry(
user_id=user_id,
block_name="TestBlock",
provider="openai",
cost_microdollars=5000,
input_tokens=100,
output_tokens=50,
model="gpt-4",
metadata={"key": "val", "nested": {"x": 1}},
)
await log_platform_cost(entry)
rows = await PrismaLog.prisma().find_many(where={"userId": user_id})
assert len(rows) == 1
assert rows[0].metadata == {"key": "val", "nested": {"x": 1}}
assert rows[0].provider == "openai"
assert rows[0].costMicrodollars == 5000
@pytest.mark.asyncio(loop_scope="session")
async def test_log_platform_cost_metadata_none(cost_log_user):
"""Verify that None metadata falls back to {} (not a DataError)."""
user_id = cost_log_user
entry = PlatformCostEntry(
user_id=user_id,
block_name="TestBlock",
provider="anthropic",
metadata=None,
)
await log_platform_cost(entry)
rows = await PrismaLog.prisma().find_many(where={"userId": user_id})
assert len(rows) == 1
assert rows[0].metadata == {}

View File

@@ -0,0 +1,286 @@
"""Unit tests for helpers and async functions in platform_cost module."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from prisma import Json
from backend.util.json import SafeJson
from .platform_cost import (
PlatformCostEntry,
_build_where,
_mask_email,
get_platform_cost_dashboard,
get_platform_cost_logs,
log_platform_cost,
log_platform_cost_safe,
)
class TestMaskEmail:
def test_typical_email(self):
assert _mask_email("user@example.com") == "us***@example.com"
def test_short_local_part(self):
assert _mask_email("a@b.com") == "a***@b.com"
def test_none_returns_none(self):
assert _mask_email(None) is None
def test_empty_string_returns_empty(self):
assert _mask_email("") == ""
def test_no_at_sign_returns_stars(self):
assert _mask_email("notanemail") == "***"
def test_two_char_local(self):
assert _mask_email("ab@domain.org") == "ab***@domain.org"
class TestBuildWhere:
def test_no_filters_returns_true(self):
sql, params = _build_where(None, None, None, None)
assert sql == "TRUE"
assert params == []
def test_start_only(self):
dt = datetime(2026, 1, 1, tzinfo=timezone.utc)
sql, params = _build_where(dt, None, None, None)
assert '"createdAt" >= $1::timestamptz' in sql
assert params == [dt]
def test_end_only(self):
dt = datetime(2026, 6, 1, tzinfo=timezone.utc)
sql, params = _build_where(None, dt, None, None)
assert '"createdAt" <= $1::timestamptz' in sql
assert params == [dt]
def test_provider_only(self):
# Provider names are normalized to lowercase at write time, so the
# filter uses a plain equality check. The input is also lowercased so
# "OpenAI" and "openai" both match stored rows.
sql, params = _build_where(None, None, "OpenAI", None)
assert '"provider" = $1' in sql
assert params == ["openai"]
def test_user_id_only(self):
sql, params = _build_where(None, None, None, "user-123")
assert '"userId" = $1' in sql
assert params == ["user-123"]
def test_all_filters(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
sql, params = _build_where(start, end, "Anthropic", "u1")
assert "$1" in sql
assert "$2" in sql
assert "$3" in sql
assert "$4" in sql
assert len(params) == 4
# Provider is lowercased at filter time to match stored lowercase values.
assert params == [start, end, "anthropic", "u1"]
def test_table_alias(self):
dt = datetime(2026, 1, 1, tzinfo=timezone.utc)
sql, params = _build_where(dt, None, None, None, table_alias="p")
assert 'p."createdAt"' in sql
assert params == [dt]
def test_clauses_joined_with_and(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
sql, _ = _build_where(start, end, None, None)
assert " AND " in sql
def _make_entry(**overrides: object) -> PlatformCostEntry:
return PlatformCostEntry.model_validate(
{
"user_id": "user-1",
"block_id": "block-1",
"block_name": "TestBlock",
"provider": "openai",
"credential_id": "cred-1",
**overrides,
}
)
class TestLogPlatformCost:
@pytest.mark.asyncio
async def test_creates_prisma_record(self):
mock_create = AsyncMock()
with patch("backend.data.platform_cost.PrismaLog.prisma") as mock_prisma:
mock_prisma.return_value.create = mock_create
entry = _make_entry(
input_tokens=100,
output_tokens=50,
cost_microdollars=5000,
model="gpt-4",
metadata={"key": "val"},
)
await log_platform_cost(entry)
mock_create.assert_awaited_once()
data = mock_create.call_args[1]["data"]
assert data["userId"] == "user-1"
assert data["blockName"] == "TestBlock"
assert data["provider"] == "openai"
# metadata must be wrapped in SafeJson (a prisma.Json subclass), not a plain dict
assert isinstance(data["metadata"], Json)
@pytest.mark.asyncio
async def test_metadata_none_passes_none(self):
mock_create = AsyncMock()
with patch("backend.data.platform_cost.PrismaLog.prisma") as mock_prisma:
mock_prisma.return_value.create = mock_create
entry = _make_entry(metadata=None)
await log_platform_cost(entry)
data = mock_create.call_args[1]["data"]
# None falls back to SafeJson({}) so Prisma always gets a valid Json value
assert isinstance(data["metadata"], Json)
assert data["metadata"] == SafeJson({})
class TestLogPlatformCostSafe:
@pytest.mark.asyncio
async def test_does_not_raise_on_error(self):
with patch("backend.data.platform_cost.PrismaLog.prisma") as mock_prisma:
mock_prisma.return_value.create = AsyncMock(
side_effect=RuntimeError("DB down")
)
entry = _make_entry()
await log_platform_cost_safe(entry)
@pytest.mark.asyncio
async def test_succeeds_when_no_error(self):
mock_create = AsyncMock()
with patch("backend.data.platform_cost.PrismaLog.prisma") as mock_prisma:
mock_prisma.return_value.create = mock_create
entry = _make_entry()
await log_platform_cost_safe(entry)
mock_create.assert_awaited_once()
class TestGetPlatformCostDashboard:
def setup_method(self):
# @cached stores results in-process; clear between tests to avoid bleed.
get_platform_cost_dashboard.cache_clear()
@pytest.mark.asyncio
async def test_returns_dashboard_with_data(self):
provider_rows = [
{
"provider": "openai",
"tracking_type": "tokens",
"total_cost": 5000,
"total_input_tokens": 1000,
"total_output_tokens": 500,
"total_duration": 10.5,
"request_count": 3,
}
]
user_rows = [
{
"user_id": "u1",
"email": "a@b.com",
"total_cost": 5000,
"total_input_tokens": 1000,
"total_output_tokens": 500,
"request_count": 3,
}
]
# Dashboard runs 3 queries: by_provider, by_user, COUNT(DISTINCT userId).
mock_query = AsyncMock(side_effect=[provider_rows, user_rows, [{"cnt": 1}]])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 5000
assert dashboard.total_requests == 3
assert dashboard.total_users == 1
assert len(dashboard.by_provider) == 1
assert dashboard.by_provider[0].provider == "openai"
assert dashboard.by_provider[0].tracking_type == "tokens"
assert dashboard.by_provider[0].total_duration_seconds == 10.5
assert len(dashboard.by_user) == 1
assert dashboard.by_user[0].email == "a***@b.com"
@pytest.mark.asyncio
async def test_returns_empty_dashboard(self):
mock_query = AsyncMock(side_effect=[[], [], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 0
assert dashboard.total_requests == 0
assert dashboard.total_users == 0
assert dashboard.by_provider == []
assert dashboard.by_user == []
@pytest.mark.asyncio
async def test_passes_filters_to_queries(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
mock_query = AsyncMock(side_effect=[[], [], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
await get_platform_cost_dashboard(
start=start, provider="openai", user_id="u1"
)
assert mock_query.await_count == 3
first_call_sql = mock_query.call_args_list[0][0][0]
assert "createdAt" in first_call_sql
class TestGetPlatformCostLogs:
@pytest.mark.asyncio
async def test_returns_logs_and_total(self):
count_rows = [{"cnt": 1}]
log_rows = [
{
"id": "log-1",
"created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
"user_id": "u1",
"email": "a@b.com",
"graph_exec_id": "g1",
"node_exec_id": "n1",
"block_name": "TestBlock",
"provider": "openai",
"tracking_type": "tokens",
"cost_microdollars": 5000,
"input_tokens": 100,
"output_tokens": 50,
"duration": 1.5,
"model": "gpt-4",
}
]
mock_query = AsyncMock(side_effect=[count_rows, log_rows])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
logs, total = await get_platform_cost_logs(page=1, page_size=10)
assert total == 1
assert len(logs) == 1
assert logs[0].id == "log-1"
assert logs[0].provider == "openai"
assert logs[0].model == "gpt-4"
@pytest.mark.asyncio
async def test_returns_empty_when_no_data(self):
mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
logs, total = await get_platform_cost_logs()
assert total == 0
assert logs == []
@pytest.mark.asyncio
async def test_pagination_offset(self):
mock_query = AsyncMock(side_effect=[[{"cnt": 100}], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
logs, total = await get_platform_cost_logs(page=3, page_size=25)
assert total == 100
second_call_args = mock_query.call_args_list[1][0]
assert 25 in second_call_args # page_size
assert 50 in second_call_args # offset = (3-1) * 25
@pytest.mark.asyncio
async def test_empty_count_returns_zero(self):
mock_query = AsyncMock(side_effect=[[], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
logs, total = await get_platform_cost_logs()
assert total == 0

View File

@@ -0,0 +1,291 @@
"""Helpers for platform cost tracking on system-credential block executions."""
import asyncio
import logging
import threading
from typing import TYPE_CHECKING, Any, cast
from backend.blocks._base import Block, BlockSchema
from backend.copilot.token_tracking import _pending_log_tasks as _copilot_tasks
from backend.copilot.token_tracking import (
_pending_log_tasks_lock as _copilot_tasks_lock,
)
from backend.data.execution import NodeExecutionEntry
from backend.data.model import NodeExecutionStats
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
from backend.executor.utils import block_usage_cost
from backend.integrations.credentials_store import is_system_credential
from backend.integrations.providers import ProviderName
if TYPE_CHECKING:
from backend.data.db_manager import DatabaseManagerAsyncClient
logger = logging.getLogger(__name__)
# Provider groupings by billing model — used when the block didn't explicitly
# declare stats.provider_cost_type and we fall back to provider-name
# heuristics. Values match ProviderName enum values.
_CHARACTER_BILLED_PROVIDERS = frozenset(
{ProviderName.D_ID.value, ProviderName.ELEVENLABS.value}
)
_WALLTIME_BILLED_PROVIDERS = frozenset(
{
ProviderName.FAL.value,
ProviderName.REVID.value,
ProviderName.REPLICATE.value,
}
)
# Hold strong references to in-flight log tasks so the event loop doesn't
# garbage-collect them mid-execution. Tasks remove themselves on completion.
# _pending_log_tasks_lock guards all reads and writes: worker threads call
# discard() via done callbacks while drain_pending_cost_logs() iterates.
_pending_log_tasks: set[asyncio.Task] = set()
_pending_log_tasks_lock = threading.Lock()
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
# shared across event loops running in different threads. Key by loop instance
# so each executor worker thread gets its own semaphore.
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}
def _get_log_semaphore() -> asyncio.Semaphore:
loop = asyncio.get_running_loop()
sem = _log_semaphores.get(loop)
if sem is None:
sem = asyncio.Semaphore(50)
_log_semaphores[loop] = sem
return sem
async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
"""Await all in-flight cost log tasks with a timeout.
Drains both the executor cost log tasks (_pending_log_tasks in this module,
used for block execution cost tracking via DatabaseManagerAsyncClient) and
the copilot cost log tasks (token_tracking._pending_log_tasks, used for
copilot LLM turns via platform_cost_db()).
Call this during graceful shutdown to flush pending INSERT tasks before
the process exits. Tasks that don't complete within `timeout` seconds are
abandoned and their failures are already logged by _safe_log.
"""
# asyncio.wait() requires all tasks to belong to the running event loop.
# _pending_log_tasks is shared across executor worker threads (each with
# its own loop), so filter to only tasks owned by the current loop.
# Acquire the lock to take a consistent snapshot (worker threads call
# discard() via done callbacks concurrently with this iteration).
current_loop = asyncio.get_running_loop()
with _pending_log_tasks_lock:
all_pending = [t for t in _pending_log_tasks if t.get_loop() is current_loop]
if all_pending:
logger.info("Draining %d executor cost log task(s)", len(all_pending))
_, still_pending = await asyncio.wait(all_pending, timeout=timeout)
if still_pending:
logger.warning(
"%d executor cost log task(s) did not complete within %.1fs",
len(still_pending),
timeout,
)
# Also drain copilot cost log tasks (token_tracking._pending_log_tasks)
with _copilot_tasks_lock:
copilot_pending = [t for t in _copilot_tasks if t.get_loop() is current_loop]
if copilot_pending:
logger.info("Draining %d copilot cost log task(s)", len(copilot_pending))
_, still_pending = await asyncio.wait(copilot_pending, timeout=timeout)
if still_pending:
logger.warning(
"%d copilot cost log task(s) did not complete within %.1fs",
len(still_pending),
timeout,
)
def _schedule_log(
db_client: "DatabaseManagerAsyncClient", entry: PlatformCostEntry
) -> None:
async def _safe_log() -> None:
async with _get_log_semaphore():
try:
await db_client.log_platform_cost(entry)
except Exception:
logger.exception(
"Failed to log platform cost for user=%s provider=%s block=%s",
entry.user_id,
entry.provider,
entry.block_name,
)
task = asyncio.create_task(_safe_log())
with _pending_log_tasks_lock:
_pending_log_tasks.add(task)
def _remove(t: asyncio.Task) -> None:
with _pending_log_tasks_lock:
_pending_log_tasks.discard(t)
task.add_done_callback(_remove)
def _extract_model_name(raw: str | dict | None) -> str | None:
"""Return a string model name from a block input field, or None.
Handles str (returned as-is), dict (e.g. an enum wrapper, skipped), and
None (no model field). Unexpected types are coerced to str as a fallback.
"""
if raw is None:
return None
if isinstance(raw, str):
return raw
if isinstance(raw, dict):
return None
return str(raw)
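Expected behavior across the shapes a `model` input field can take (illustrative; the dict contents are irrelevant, any dict is skipped):

```python
assert _extract_model_name("gpt-4") == "gpt-4"
assert _extract_model_name(None) is None
assert _extract_model_name({"value": "gpt-4"}) is None  # enum-wrapper dicts skipped
assert _extract_model_name(123) == "123"                # unexpected types coerced
```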
def resolve_tracking(
provider: str,
stats: NodeExecutionStats,
input_data: dict[str, Any],
) -> tuple[str, float]:
"""Return (tracking_type, tracking_amount) based on provider billing model.
Preference order:
1. Block-declared: if the block set `provider_cost_type` on its stats,
honor it directly (paired with `provider_cost` as the amount).
2. Heuristic fallback: infer from `provider_cost`/token counts, then
from provider name for per-character / per-second billing.
"""
# 1. Block explicitly declared its cost type (only when an amount is present)
if stats.provider_cost_type and stats.provider_cost is not None:
return stats.provider_cost_type, max(0.0, stats.provider_cost)
# 2. Provider returned actual USD cost (OpenRouter, Exa)
if stats.provider_cost is not None:
return "cost_usd", max(0.0, stats.provider_cost)
# 3. LLM providers: track by tokens
if stats.input_token_count or stats.output_token_count:
return "tokens", float(
(stats.input_token_count or 0) + (stats.output_token_count or 0)
)
# 4. Provider-specific billing heuristics
# TTS: billed per character of input text
if provider == ProviderName.UNREAL_SPEECH.value:
text = input_data.get("text", "")
return "characters", float(len(text)) if isinstance(text, str) else 0.0
# D-ID + ElevenLabs voice: billed per character of script
if provider in _CHARACTER_BILLED_PROVIDERS:
text = (
input_data.get("script_input", "")
or input_data.get("text", "")
or input_data.get("script", "") # VideoNarrationBlock uses `script`
)
return "characters", float(len(text)) if isinstance(text, str) else 0.0
# E2B: billed per second of sandbox time
if provider == ProviderName.E2B.value:
return "sandbox_seconds", round(stats.walltime, 3) if stats.walltime else 0.0
# Video/image gen: walltime includes queue + generation + polling
if provider in _WALLTIME_BILLED_PROVIDERS:
return "walltime_seconds", round(stats.walltime, 3) if stats.walltime else 0.0
# Per-request: Google Maps, Ideogram, Nvidia, Apollo, etc.
# All billed per API call - count 1 per block execution.
return "per_run", 1.0
async def log_system_credential_cost(
node_exec: NodeExecutionEntry,
block: Block,
stats: NodeExecutionStats,
db_client: "DatabaseManagerAsyncClient",
) -> None:
"""Check if a system credential was used and log the platform cost.
Routes through DatabaseManagerAsyncClient so the write goes via the
message-passing DB service rather than calling Prisma directly (which
is not connected in the executor process).
Logs only the first matching system credential field (one log per
execution). Any unexpected error is caught and logged — cost logging
is strictly best-effort and must never disrupt block execution.
Note: costMicrodollars is left null for providers that don't return
a USD cost. The credit_cost in metadata captures our internal credit
charge as a proxy.
"""
try:
if node_exec.execution_context.dry_run:
return
input_data = node_exec.inputs
input_model = cast(type[BlockSchema], block.input_schema)
for field_name in input_model.get_credentials_fields():
cred_data = input_data.get(field_name)
if not cred_data or not isinstance(cred_data, dict):
continue
cred_id = cred_data.get("id", "")
if not cred_id or not is_system_credential(cred_id):
continue
model_name = _extract_model_name(input_data.get("model"))
credit_cost, _ = block_usage_cost(block=block, input_data=input_data)
provider_name = cred_data.get("provider", "unknown")
tracking_type, tracking_amount = resolve_tracking(
provider=provider_name,
stats=stats,
input_data=input_data,
)
# Only treat provider_cost as USD when the tracking type says so.
# For other types (items, characters, per_run, ...) the
# provider_cost field holds the raw amount, not a dollar value.
# Use tracking_amount (the normalized value from resolve_tracking)
# rather than raw stats.provider_cost to avoid unit mismatches.
cost_microdollars = None
if tracking_type == "cost_usd":
cost_microdollars = usd_to_microdollars(tracking_amount)
meta: dict[str, Any] = {
"tracking_type": tracking_type,
"tracking_amount": tracking_amount,
}
if credit_cost is not None:
meta["credit_cost"] = credit_cost
if stats.provider_cost is not None:
# Use 'provider_cost_raw' — the value's unit varies by tracking
# type (USD for cost_usd, count for items/characters/per_run, etc.)
meta["provider_cost_raw"] = stats.provider_cost
_schedule_log(
db_client,
PlatformCostEntry(
user_id=node_exec.user_id,
graph_exec_id=node_exec.graph_exec_id,
node_exec_id=node_exec.node_exec_id,
graph_id=node_exec.graph_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block_name=block.name,
provider=provider_name,
credential_id=cred_id,
cost_microdollars=cost_microdollars,
input_tokens=stats.input_token_count,
output_tokens=stats.output_token_count,
data_size=stats.output_size if stats.output_size > 0 else None,
duration=stats.walltime if stats.walltime > 0 else None,
model=model_name,
tracking_type=tracking_type,
tracking_amount=tracking_amount,
metadata=meta,
),
)
return # One log per execution is enough
except Exception:
logger.exception("log_system_credential_cost failed unexpectedly")

View File

@@ -45,6 +45,10 @@ from backend.data.notifications import (
ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.cost_tracking import (
drain_pending_cost_logs,
log_system_credential_cost,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.util import json
@@ -303,9 +307,18 @@ async def execute_node(
# Handle regular credentials fields
for field_name, input_type in input_model.get_credentials_fields().items():
-            # Dry-run platform credentials bypass the credential store
+            # Dry-run platform credentials bypass the credential store.
+            # Keep the existing credential metadata so _execute's input_schema(**...)
+            # doesn't fail on the required field. If no metadata is present,
+            # synthesize a minimal placeholder from the platform credentials.
             if _dry_run_creds is not None:
-                input_data[field_name] = None
+                if input_data.get(field_name) is None:
+                    input_data[field_name] = {
+                        "id": _dry_run_creds.id,
+                        "provider": _dry_run_creds.provider,
+                        "type": _dry_run_creds.type,
+                        "title": _dry_run_creds.title,
+                    }
                 extra_exec_kwargs[field_name] = _dry_run_creds
                 continue
@@ -692,6 +705,15 @@ class ExecutionProcessor:
stats=graph_stats,
)
# Log platform cost if system credentials were used (only on success)
if status == ExecutionStatus.COMPLETED:
await log_system_credential_cost(
node_exec=node_exec,
block=node.block,
stats=execution_stats,
db_client=db_client,
)
return execution_stats
@async_time_measured
@@ -2044,6 +2066,18 @@ class ExecutionManager(AppProcess):
prefix + " [cancel-consumer]",
)
# Drain any in-flight cost log tasks before exit so we don't silently
# drop INSERT operations during deployments.
loop = getattr(self, "node_execution_loop", None)
if loop is not None and loop.is_running():
try:
asyncio.run_coroutine_threadsafe(
drain_pending_cost_logs(), loop
).result(timeout=10)
logger.info(f"{prefix} ✅ Cost log tasks drained")
except Exception as e:
logger.warning(f"{prefix} ⚠️ Failed to drain cost log tasks: {e}")
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
super().cleanup()

View File

@@ -0,0 +1,623 @@
"""Unit tests for resolve_tracking and log_system_credential_cost."""
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.execution import ExecutionContext, NodeExecutionEntry
from backend.data.model import NodeExecutionStats
from backend.executor.cost_tracking import (
drain_pending_cost_logs,
log_system_credential_cost,
resolve_tracking,
)
# ---------------------------------------------------------------------------
# resolve_tracking
# ---------------------------------------------------------------------------
class TestResolveTracking:
def _stats(self, **overrides: Any) -> NodeExecutionStats:
return NodeExecutionStats(**overrides)
def test_provider_cost_returns_cost_usd(self):
stats = self._stats(provider_cost=0.0042)
tt, amt = resolve_tracking("openai", stats, {})
assert tt == "cost_usd"
assert amt == 0.0042
def test_token_counts_return_tokens(self):
stats = self._stats(input_token_count=300, output_token_count=100)
tt, amt = resolve_tracking("anthropic", stats, {})
assert tt == "tokens"
assert amt == 400.0
def test_token_counts_only_input(self):
stats = self._stats(input_token_count=500)
tt, amt = resolve_tracking("groq", stats, {})
assert tt == "tokens"
assert amt == 500.0
def test_unreal_speech_returns_characters(self):
stats = self._stats()
tt, amt = resolve_tracking("unreal_speech", stats, {"text": "Hello world"})
assert tt == "characters"
assert amt == 11.0
def test_unreal_speech_empty_text(self):
stats = self._stats()
tt, amt = resolve_tracking("unreal_speech", stats, {"text": ""})
assert tt == "characters"
assert amt == 0.0
def test_unreal_speech_non_string_text(self):
stats = self._stats()
tt, amt = resolve_tracking("unreal_speech", stats, {"text": 123})
assert tt == "characters"
assert amt == 0.0
def test_d_id_uses_script_input(self):
stats = self._stats()
tt, amt = resolve_tracking("d_id", stats, {"script_input": "Hello"})
assert tt == "characters"
assert amt == 5.0
def test_elevenlabs_uses_text(self):
stats = self._stats()
tt, amt = resolve_tracking("elevenlabs", stats, {"text": "Say this"})
assert tt == "characters"
assert amt == 8.0
def test_elevenlabs_fallback_to_text_when_no_script_input(self):
stats = self._stats()
tt, amt = resolve_tracking("elevenlabs", stats, {"text": "Fallback text"})
assert tt == "characters"
assert amt == 13.0
def test_elevenlabs_uses_script_field(self):
"""VideoNarrationBlock (elevenlabs) uses `script` field, not script_input/text."""
stats = self._stats()
tt, amt = resolve_tracking("elevenlabs", stats, {"script": "Narration"})
assert tt == "characters"
assert amt == 9.0
def test_block_declared_cost_type_items(self):
"""Block explicitly setting provider_cost_type='items' short-circuits heuristics."""
stats = self._stats(provider_cost=5.0, provider_cost_type="items")
tt, amt = resolve_tracking("google_maps", stats, {})
assert tt == "items"
assert amt == 5.0
def test_block_declared_cost_type_characters(self):
"""TTS block can declare characters directly, bypassing input_data lookup."""
stats = self._stats(provider_cost=42.0, provider_cost_type="characters")
tt, amt = resolve_tracking("unreal_speech", stats, {})
assert tt == "characters"
assert amt == 42.0
def test_block_declared_cost_type_wins_over_tokens(self):
"""provider_cost_type takes precedence over token-based heuristic."""
stats = self._stats(
provider_cost=1.0,
provider_cost_type="per_run",
input_token_count=500,
)
tt, amt = resolve_tracking("openai", stats, {})
assert tt == "per_run"
assert amt == 1.0
def test_e2b_returns_sandbox_seconds(self):
stats = self._stats(walltime=45.123)
tt, amt = resolve_tracking("e2b", stats, {})
assert tt == "sandbox_seconds"
assert amt == 45.123
def test_e2b_no_walltime(self):
stats = self._stats(walltime=0)
tt, amt = resolve_tracking("e2b", stats, {})
assert tt == "sandbox_seconds"
assert amt == 0.0
def test_fal_returns_walltime(self):
stats = self._stats(walltime=12.5)
tt, amt = resolve_tracking("fal", stats, {})
assert tt == "walltime_seconds"
assert amt == 12.5
def test_revid_returns_walltime(self):
stats = self._stats(walltime=60.0)
tt, amt = resolve_tracking("revid", stats, {})
assert tt == "walltime_seconds"
assert amt == 60.0
def test_replicate_returns_walltime(self):
stats = self._stats(walltime=30.0)
tt, amt = resolve_tracking("replicate", stats, {})
assert tt == "walltime_seconds"
assert amt == 30.0
def test_unknown_provider_returns_per_run(self):
stats = self._stats()
tt, amt = resolve_tracking("google_maps", stats, {})
assert tt == "per_run"
assert amt == 1.0
def test_negative_provider_cost_clamped_to_zero(self):
"""Negative provider_cost values must be clamped to 0."""
stats = self._stats(provider_cost=-0.005)
tt, amt = resolve_tracking("openrouter", stats, {})
assert tt == "cost_usd"
assert amt == 0.0
def test_negative_block_declared_cost_clamped_to_zero(self):
"""Negative block-declared cost must also be clamped to 0."""
stats = self._stats(provider_cost=-1.0, provider_cost_type="items")
tt, amt = resolve_tracking("google_maps", stats, {})
assert tt == "items"
assert amt == 0.0
def test_provider_cost_takes_precedence_over_tokens(self):
stats = self._stats(
provider_cost=0.01, input_token_count=500, output_token_count=200
)
tt, amt = resolve_tracking("openai", stats, {})
assert tt == "cost_usd"
assert amt == 0.01
def test_provider_cost_zero_is_not_none(self):
"""provider_cost=0.0 is falsy but should still be tracked as cost_usd
(e.g. free-tier or fully-cached responses from OpenRouter)."""
stats = self._stats(provider_cost=0.0)
tt, amt = resolve_tracking("open_router", stats, {})
assert tt == "cost_usd"
assert amt == 0.0
def test_tokens_take_precedence_over_provider_specific(self):
stats = self._stats(input_token_count=100, walltime=10.0)
tt, amt = resolve_tracking("fal", stats, {})
assert tt == "tokens"
assert amt == 100.0
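
Taken together, these cases pin down a precedence order for resolve_tracking. A condensed sketch reconstructed from the assertions above, not the actual cost_tracking source (stats field defaults are assumed to be None/0):

```python
TTS_PROVIDERS = {"unreal_speech", "d_id", "elevenlabs"}
WALLTIME_PROVIDERS = {"fal", "revid", "replicate"}


def resolve_tracking_sketch(provider, stats, input_data):
    # 1. A block-declared cost type short-circuits every heuristic.
    if stats.provider_cost_type is not None:
        return stats.provider_cost_type, max(stats.provider_cost or 0.0, 0.0)
    # 2. An explicit provider cost wins next; 0.0 still counts (free tier /
    #    fully cached), and negative values are clamped to zero.
    if stats.provider_cost is not None:
        return "cost_usd", max(stats.provider_cost, 0.0)
    # 3. Token counts (input + output) beat provider-specific rules.
    tokens = (stats.input_token_count or 0) + (stats.output_token_count or 0)
    if tokens:
        return "tokens", float(tokens)
    # 4. TTS providers bill by characters of the narrated text.
    if provider in TTS_PROVIDERS:
        text = (input_data.get("script_input")
                or input_data.get("script")
                or input_data.get("text"))
        return "characters", float(len(text)) if isinstance(text, str) else 0.0
    # 5./6. Sandbox and media providers bill by wall-clock time.
    if provider == "e2b":
        return "sandbox_seconds", float(stats.walltime or 0.0)
    if provider in WALLTIME_PROVIDERS:
        return "walltime_seconds", float(stats.walltime or 0.0)
    # 7. Everything else falls back to a flat per-run unit.
    return "per_run", 1.0
```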
# ---------------------------------------------------------------------------
# log_system_credential_cost
# ---------------------------------------------------------------------------
def _make_db_client() -> MagicMock:
db_client = MagicMock()
db_client.log_platform_cost = AsyncMock()
return db_client
def _make_block(has_credentials: bool = True) -> MagicMock:
block = MagicMock()
block.name = "TestBlock"
input_schema = MagicMock()
if has_credentials:
input_schema.get_credentials_fields.return_value = {"credentials": MagicMock()}
else:
input_schema.get_credentials_fields.return_value = {}
block.input_schema = input_schema
return block
def _make_node_exec(
inputs: dict | None = None,
dry_run: bool = False,
) -> NodeExecutionEntry:
return NodeExecutionEntry(
user_id="user-1",
graph_exec_id="gx-1",
graph_id="g-1",
graph_version=1,
node_exec_id="nx-1",
node_id="n-1",
block_id="b-1",
inputs=inputs or {},
execution_context=ExecutionContext(dry_run=dry_run),
)
class TestLogSystemCredentialCost:
@pytest.mark.asyncio
async def test_skips_dry_run(self):
db_client = _make_db_client()
node_exec = _make_node_exec(dry_run=True)
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
db_client.log_platform_cost.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_no_credential_fields(self):
db_client = _make_db_client()
node_exec = _make_node_exec(inputs={})
block = _make_block(has_credentials=False)
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
db_client.log_platform_cost.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_cred_data_missing(self):
db_client = _make_db_client()
node_exec = _make_node_exec(inputs={})
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
db_client.log_platform_cost.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_not_system_credential(self):
db_client = _make_db_client()
with patch(
"backend.executor.cost_tracking.is_system_credential",
return_value=False,
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "user-cred-123", "provider": "openai"},
}
)
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
db_client.log_platform_cost.assert_not_awaited()
@pytest.mark.asyncio
async def test_logs_with_system_credential(self):
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(10, None),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred-1", "provider": "openai"},
"model": "gpt-4",
}
)
block = _make_block()
stats = NodeExecutionStats(input_token_count=500, output_token_count=200)
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
db_client.log_platform_cost.assert_awaited_once()
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.user_id == "user-1"
assert entry.provider == "openai"
assert entry.block_name == "TestBlock"
assert entry.model == "gpt-4"
assert entry.input_tokens == 500
assert entry.output_tokens == 200
assert entry.tracking_type == "tokens"
assert entry.metadata["tracking_type"] == "tokens"
assert entry.metadata["tracking_amount"] == 700.0
assert entry.metadata["credit_cost"] == 10
@pytest.mark.asyncio
async def test_logs_with_provider_cost(self):
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(5, None),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred-2", "provider": "open_router"},
}
)
block = _make_block()
stats = NodeExecutionStats(provider_cost=0.0015)
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.cost_microdollars == 1500
assert entry.tracking_type == "cost_usd"
assert entry.metadata["tracking_type"] == "cost_usd"
assert entry.metadata["provider_cost_raw"] == 0.0015
@pytest.mark.asyncio
async def test_model_name_enum_converted_to_str(self):
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(0, None),
),
):
from enum import Enum
class FakeModel(Enum):
GPT4 = "gpt-4"
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred", "provider": "openai"},
"model": FakeModel.GPT4,
}
)
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.model == "FakeModel.GPT4"
@pytest.mark.asyncio
async def test_model_name_dict_becomes_none(self):
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(0, None),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred", "provider": "openai"},
"model": {"nested": "value"},
}
)
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.model is None
@pytest.mark.asyncio
async def test_does_not_raise_when_block_usage_cost_raises(self):
"""log_system_credential_cost must swallow exceptions from block_usage_cost."""
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
side_effect=RuntimeError("pricing lookup failed"),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred", "provider": "openai"},
}
)
block = _make_block()
stats = NodeExecutionStats()
# Should not raise — outer except must catch block_usage_cost error
await log_system_credential_cost(node_exec, block, stats, db_client)
@pytest.mark.asyncio
async def test_round_instead_of_int_for_microdollars(self):
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(0, None),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred", "provider": "openai"},
}
)
block = _make_block()
# Float math can land just below the integer boundary, in which
# case int() would truncate to 1499 while round() correctly gives 1500.
stats = NodeExecutionStats(provider_cost=0.0015)
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.cost_microdollars == 1500
@pytest.mark.asyncio
async def test_per_run_metadata_has_no_provider_cost_raw(self):
"""For per-run providers (google_maps etc), provider_cost_raw is absent
from metadata since stats.provider_cost is None."""
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(0, None),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred", "provider": "google_maps"},
}
)
block = _make_block()
stats = NodeExecutionStats() # no provider_cost
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.tracking_type == "per_run"
assert "provider_cost_raw" not in (entry.metadata or {})
# ---------------------------------------------------------------------------
# merge_stats accumulation
# ---------------------------------------------------------------------------
class TestMergeStats:
"""Tests for NodeExecutionStats accumulation via += (used by Block.merge_stats)."""
def test_accumulates_output_size(self):
stats = NodeExecutionStats()
stats += NodeExecutionStats(output_size=10)
stats += NodeExecutionStats(output_size=25)
assert stats.output_size == 35
def test_accumulates_tokens(self):
stats = NodeExecutionStats()
stats += NodeExecutionStats(input_token_count=100, output_token_count=50)
stats += NodeExecutionStats(input_token_count=200, output_token_count=150)
assert stats.input_token_count == 300
assert stats.output_token_count == 200
def test_preserves_provider_cost(self):
stats = NodeExecutionStats()
stats += NodeExecutionStats(provider_cost=0.005)
stats += NodeExecutionStats(output_size=10)
assert stats.provider_cost == 0.005
assert stats.output_size == 10
def test_provider_cost_accumulates(self):
"""Multiple merge_stats with provider_cost should sum (multi-round
tool-calling in copilot / retries can report cost separately)."""
stats = NodeExecutionStats()
stats += NodeExecutionStats(provider_cost=0.001)
stats += NodeExecutionStats(provider_cost=0.002)
stats += NodeExecutionStats(provider_cost=0.003)
assert stats.provider_cost == pytest.approx(0.006)
def test_provider_cost_none_does_not_overwrite(self):
"""A None provider_cost must not wipe a previously-set value."""
stats = NodeExecutionStats(provider_cost=0.01)
stats += NodeExecutionStats() # provider_cost=None by default
assert stats.provider_cost == 0.01
def test_provider_cost_type_last_write_wins(self):
"""provider_cost_type is a Literal — last set value wins on merge."""
stats = NodeExecutionStats(provider_cost_type="tokens")
stats += NodeExecutionStats(provider_cost_type="items")
assert stats.provider_cost_type == "items"
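
The merge semantics these tests rely on: numeric counters sum, a None never overwrites a previously set value, and Literal-style fields take the last non-None write. A minimal standalone illustration, not the actual NodeExecutionStats code:

```python
from typing import Optional


class StatsSketch:
    """Toy stand-in for NodeExecutionStats' += behaviour."""

    def __init__(self, provider_cost: Optional[float] = None,
                 provider_cost_type: Optional[str] = None,
                 input_token_count: int = 0):
        self.provider_cost = provider_cost
        self.provider_cost_type = provider_cost_type
        self.input_token_count = input_token_count

    def __iadd__(self, other: "StatsSketch") -> "StatsSketch":
        # Numeric counters accumulate.
        self.input_token_count += other.input_token_count
        # Optional floats sum, but a None never wipes an existing value.
        if other.provider_cost is not None:
            self.provider_cost = (self.provider_cost or 0.0) + other.provider_cost
        # Literal-ish fields: last non-None write wins.
        if other.provider_cost_type is not None:
            self.provider_cost_type = other.provider_cost_type
        return self


s = StatsSketch()
s += StatsSketch(provider_cost=0.001)
s += StatsSketch(provider_cost=0.002, provider_cost_type="items")
s += StatsSketch()  # a None provider_cost must not wipe the sum
assert abs(s.provider_cost - 0.003) < 1e-9 and s.provider_cost_type == "items"
```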
# ---------------------------------------------------------------------------
# on_node_execution -> log_system_credential_cost integration
# ---------------------------------------------------------------------------
class TestManagerCostTrackingIntegration:
@pytest.mark.asyncio
async def test_log_called_with_accumulated_stats(self):
"""Verify that log_system_credential_cost receives stats that could
have been accumulated by merge_stats across multiple yield steps."""
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(5, None),
),
):
stats = NodeExecutionStats()
stats += NodeExecutionStats(output_size=10, input_token_count=100)
stats += NodeExecutionStats(output_size=25, input_token_count=200)
assert stats.output_size == 35
assert stats.input_token_count == 300
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred-acc", "provider": "openai"},
"model": "gpt-4",
}
)
block = _make_block()
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
db_client.log_platform_cost.assert_awaited_once()
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.input_tokens == 300
assert entry.tracking_type == "tokens"
assert entry.metadata["tracking_amount"] == 300.0
@pytest.mark.asyncio
async def test_skips_cost_log_when_status_is_failed(self):
"""Manager only calls log_system_credential_cost on COMPLETED status.
This test verifies the guard condition `if status == COMPLETED` directly:
the call happens only on success, never for FAILED or ERROR executions.
"""
from backend.data.execution import ExecutionStatus
db_client = _make_db_client()
node_exec = _make_node_exec(
inputs={"credentials": {"id": "sys-cred", "provider": "openai"}}
)
block = _make_block()
stats = NodeExecutionStats(input_token_count=100)
# Simulate the manager guard: only call on COMPLETED
status = ExecutionStatus.FAILED
if status == ExecutionStatus.COMPLETED:
await log_system_credential_cost(node_exec, block, stats, db_client)
db_client.log_platform_cost.assert_not_awaited()
# ---------------------------------------------------------------------------
# drain_pending_cost_logs
# ---------------------------------------------------------------------------
class TestDrainPendingCostLogs:
@pytest.mark.asyncio
async def test_drain_empty_set_completes(self):
"""drain_pending_cost_logs should succeed silently with no pending tasks."""
# Ensure both pending task sets are empty before calling drain
import backend.copilot.token_tracking as tt
import backend.executor.cost_tracking as ct
ct._pending_log_tasks.clear()
tt._pending_log_tasks.clear()
# Should not raise
await drain_pending_cost_logs(timeout=1.0)
@pytest.mark.asyncio
async def test_drain_awaits_in_flight_tasks(self):
"""drain_pending_cost_logs waits for tasks on the current loop."""
import backend.executor.cost_tracking as ct
finished = []
async def _slow():
await asyncio.sleep(0)
finished.append(1)
task = asyncio.ensure_future(_slow())
with ct._pending_log_tasks_lock:
ct._pending_log_tasks.add(task)
task.add_done_callback(lambda t: ct._pending_log_tasks.discard(t))
await drain_pending_cost_logs(timeout=2.0)
assert finished == [1], "drain_pending_cost_logs should have awaited the task"

View File

@@ -18,6 +18,7 @@ from backend.executor.simulator import (
_truncate_input_values,
_truncate_value,
build_simulation_prompt,
get_dry_run_credentials,
prepare_dry_run,
simulate_block,
)
@@ -234,6 +235,42 @@ class TestPrepareDryRun:
assert result is None
class TestGetDryRunCredentials:
"""get_dry_run_credentials pops _dry_run_api_key and returns APIKeyCredentials.
The returned object must have fields that can be serialised into a valid
CredentialsMetaInput placeholder dict for manager.py's schema-construction fix
(Bug: manager.py set input_data[field_name] = None, which caused
_execute's input_schema(**...) to fail because required credential fields were
missing after the None-filter pass).
"""
def test_returns_credentials_when_key_present(self) -> None:
input_data = {"_dry_run_api_key": "sk-or-test", "other": "val"}
creds = get_dry_run_credentials(input_data)
assert creds is not None
assert creds.api_key.get_secret_value() == "sk-or-test"
# key is consumed from input_data
assert "_dry_run_api_key" not in input_data
def test_returns_none_when_key_absent(self) -> None:
input_data: dict = {"other": "val"}
creds = get_dry_run_credentials(input_data)
assert creds is None
def test_credentials_have_metadata_fields_for_placeholder(self) -> None:
"""The returned credentials must have id, provider, type, and title so
manager.py can synthesise a valid CredentialsMetaInput placeholder."""
from backend.integrations.providers import ProviderName
creds = get_dry_run_credentials({"_dry_run_api_key": "sk-or-test"})
assert creds is not None
assert creds.id == "dry-run-platform"
assert creds.provider == ProviderName.OPEN_ROUTER
assert creds.type == "api_key"
assert creds.title is not None
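
What these assertions imply about get_dry_run_credentials, as a sketch: the construction below is reconstructed from the asserted fields, and the title string in particular is an assumption since the test only checks it is non-None:

```python
from pydantic import SecretStr

from backend.data.model import APIKeyCredentials
from backend.integrations.providers import ProviderName


def get_dry_run_credentials_sketch(input_data: dict):
    # Consume the key so it never leaks into the block's real inputs.
    api_key = input_data.pop("_dry_run_api_key", None)
    if api_key is None:
        return None
    return APIKeyCredentials(
        id="dry-run-platform",
        provider=ProviderName.OPEN_ROUTER,
        api_key=SecretStr(api_key),
        title="Dry-run platform credentials",  # assumed wording
    )
```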
# ---------------------------------------------------------------------------
# simulate_block input/output passthrough
# ---------------------------------------------------------------------------

View File

@@ -1,10 +1,13 @@
from datetime import datetime, timezone
from typing import cast
import pytest
from pytest_mock import MockerFixture
from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
from backend.data.execution import ExecutionStatus
from backend.data.execution import ExecutionStatus, GraphExecutionWithNodes
from backend.data.model import User
from backend.executor.utils import add_graph_execution
from backend.util.mock import MockObject
@@ -473,6 +476,104 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
assert result2 == mock_graph_exec_2
# ============================================================================
# Regression test: RPC layer returns typed User model, not raw dict
# ============================================================================
@pytest.mark.asyncio
async def test_add_graph_execution_via_rpc_returns_typed_user(
mocker: MockerFixture,
):
"""
Regression test: `add_graph_execution` accesses `user.timezone` on the User
returned by `get_user_by_id`. This test verifies the downstream code path
completes without AttributeError when `get_user_by_id` returns a proper typed
User model. Note: the mock returns a User directly — _get_return deserialization
is not exercised here; see TestGetReturn in util/service_test.py for that.
"""
graph_id = "test-graph-id"
user_id = "test-user-id"
mock_graph = mocker.MagicMock()
mock_graph.version = 1
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "exec-id-rpc"
mock_graph_exec.node_executions = []
mock_graph_exec.status = ExecutionStatus.QUEUED
mock_graph_exec.graph_version = 1
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
mock_queue = mocker.AsyncMock()
mock_event_bus = mocker.MagicMock()
mock_event_bus.publish = mocker.AsyncMock()
mock_validate = mocker.patch(
"backend.executor.utils.validate_and_construct_node_execution_input"
)
mock_validate.return_value = (mock_graph, [], {}, set())
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_prisma.is_connected.return_value = (
False # prisma not connected: uses RPC path instead
)
# The RPC layer (_get_return) deserializes JSON dicts into typed Pydantic models.
# The mock simulates what add_graph_execution receives after that deserialization:
# a proper User model, not a raw dict.
mock_user = User(
id=user_id,
email="test@example.com",
name=None,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
stripe_customer_id=None,
top_up_config=None,
timezone="UTC",
)
mock_db_client = mocker.MagicMock()
mock_db_client.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
mock_db_client.get_graph_settings = mocker.AsyncMock(
return_value=mocker.MagicMock(
human_in_the_loop_safe_mode=False, sensitive_action_safe_mode=False
)
)
mock_db_client.create_graph_execution = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_db_client.update_graph_execution_stats = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_db_client.update_node_execution_status_batch = mocker.AsyncMock()
mock_workspace = mocker.MagicMock()
mock_workspace.id = "ws-id"
mock_db_client.get_or_create_workspace = mocker.AsyncMock(
return_value=mock_workspace
)
mock_db_client.increment_onboarding_runs = mocker.AsyncMock()
mocker.patch(
"backend.executor.utils.get_database_manager_async_client",
return_value=mock_db_client,
)
mocker.patch(
"backend.executor.utils.get_async_execution_queue", return_value=mock_queue
)
mocker.patch(
"backend.executor.utils.get_async_execution_event_bus",
return_value=mock_event_bus,
)
# Must not raise AttributeError: 'dict' object has no attribute 'timezone'
result = await add_graph_execution(
graph_id=graph_id,
user_id=user_id,
)
assert result == mock_graph_exec
# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================

View File

@@ -26,6 +26,7 @@ from typing import (
)
import httpx
import sentry_sdk
import uvicorn
from fastapi import FastAPI, Request, responses
from prisma.errors import DataError, UniqueViolationError
@@ -711,16 +712,16 @@ def get_service_client(
def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
"""Validate and coerce the RPC result to the expected return type.
Falls back to the raw result with a warning if validation fails.
Falls back to the raw result with a warning and Sentry capture if validation fails.
"""
if expected_return:
try:
return expected_return.validate_python(result)
except Exception as e:
logger.warning(
"RPC return type validation failed, using raw result: %s",
type(e).__name__,
f"RPC return type validation failed for {type(e).__name__}: {e}"
)
sentry_sdk.capture_exception(e)
return result
return result

View File

@@ -1,13 +1,17 @@
import asyncio
import contextlib
import time
from datetime import datetime, timezone
from functools import cached_property
from typing import Any, Protocol, cast
from unittest.mock import Mock
import httpx
import pytest
from prisma.errors import DataError, UniqueViolationError
from pydantic import TypeAdapter
from backend.data.model import User
from backend.util.service import (
AppService,
AppServiceClient,
@@ -21,6 +25,10 @@ from backend.util.service import (
TEST_SERVICE_PORT = 8765
class _SupportsGetReturn(Protocol):
def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any: ...
class ServiceTest(AppService):
def __init__(self):
super().__init__()
@@ -688,3 +696,46 @@ async def test_health_check_during_shutdown(test_service):
except (httpx.ConnectError, httpx.ConnectTimeout):
# Connection refused/timeout is also acceptable
pass
# ============================================================================
# Unit tests for DynamicClient._get_return
# ============================================================================
class TestGetReturn:
"""Direct unit tests for DynamicClient._get_return typed-return contract."""
def _make_client(self) -> _SupportsGetReturn:
return cast(_SupportsGetReturn, get_service_client(ServiceTestClient))
def test_valid_dict_is_deserialized_to_user_model(self):
"""TypeAdapter(User) + valid dict → User model returned with .timezone accessible.
User.model_config uses extra='ignore' so unknown fields (e.g. new columns added
in a newer database-manager deploy) are silently dropped instead of raising
ValidationError — making the RPC layer forward-compatible during rolling deploys.
"""
now = datetime.now(timezone.utc).isoformat()
valid_dict = {
"id": "user-id",
"email": "test@example.com",
"created_at": now,
"updated_at": now,
"unknown_future_field": "some_value", # simulates a new DB field during deploy
}
client = self._make_client()
adapter = TypeAdapter(User)
result = client._get_return(adapter, valid_dict)
assert isinstance(result, User)
assert result.timezone is not None
def test_invalid_dict_falls_back_to_raw_result(self):
"""TypeAdapter(User) + invalid dict (missing required fields) → fallback returns raw dict."""
invalid_dict = {"id": "user-id"} # missing email, created_at, updated_at
client = self._make_client()
adapter = TypeAdapter(User)
result = client._get_return(adapter, invalid_dict)
assert result == invalid_dict
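
The forward-compatibility claim in the first test hinges on pydantic's extra="ignore". A standalone demonstration of the same contract with a toy model, not the real User:

```python
from pydantic import BaseModel, ConfigDict, TypeAdapter


class ToyUser(BaseModel):
    # extra="ignore": unknown keys are dropped rather than raising, which
    # is what keeps the RPC layer tolerant during rolling deploys.
    model_config = ConfigDict(extra="ignore")
    id: str
    email: str


adapter = TypeAdapter(ToyUser)
user = adapter.validate_python(
    {"id": "u1", "email": "a@b.c", "unknown_future_field": 1}
)
assert isinstance(user, ToyUser) and not hasattr(user, "unknown_future_field")

# A dict missing required fields raises ValidationError instead, which is
# exactly what _get_return catches before falling back to the raw result.
```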

View File

@@ -155,6 +155,7 @@ class WorkspaceManager:
path: Optional[str] = None,
mime_type: Optional[str] = None,
overwrite: bool = False,
metadata: Optional[dict] = None,
) -> WorkspaceFile:
"""
Write file to workspace.
@@ -168,6 +169,7 @@ class WorkspaceManager:
path: Virtual path (defaults to "/{filename}", session-scoped if session_id set)
mime_type: MIME type (auto-detected if not provided)
overwrite: Whether to overwrite existing file at path
metadata: Optional metadata dict (e.g., origin tracking)
Returns:
Created WorkspaceFile instance
@@ -246,6 +248,7 @@ class WorkspaceManager:
mime_type=mime_type,
size_bytes=len(content),
checksum=checksum,
metadata=metadata,
)
except UniqueViolationError:
if retries > 0:

View File

@@ -0,0 +1,43 @@
-- CreateTable
CREATE TABLE "PlatformCostLog" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT,
"graphExecId" TEXT,
"nodeExecId" TEXT,
"graphId" TEXT,
"nodeId" TEXT,
"blockId" TEXT,
"blockName" TEXT,
"provider" TEXT NOT NULL,
"credentialId" TEXT,
"costMicrodollars" BIGINT,
"inputTokens" INTEGER,
"outputTokens" INTEGER,
"dataSize" INTEGER,
"duration" DOUBLE PRECISION,
"model" TEXT,
"trackingType" TEXT,
"trackingAmount" DOUBLE PRECISION,
"metadata" JSONB,
CONSTRAINT "PlatformCostLog_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "PlatformCostLog_userId_createdAt_idx" ON "PlatformCostLog"("userId", "createdAt");
-- CreateIndex
CREATE INDEX "PlatformCostLog_provider_createdAt_idx" ON "PlatformCostLog"("provider", "createdAt");
-- CreateIndex
CREATE INDEX "PlatformCostLog_createdAt_idx" ON "PlatformCostLog"("createdAt");
-- CreateIndex
CREATE INDEX "PlatformCostLog_graphExecId_idx" ON "PlatformCostLog"("graphExecId");
-- CreateIndex
CREATE INDEX "PlatformCostLog_provider_trackingType_idx" ON "PlatformCostLog"("provider", "trackingType");
-- AddForeignKey
ALTER TABLE "PlatformCostLog" ADD CONSTRAINT "PlatformCostLog_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@@ -75,6 +75,8 @@ model User {
PendingHumanReviews PendingHumanReview[]
Workspace UserWorkspace?
PlatformCostLogs PlatformCostLog[]
// OAuth Provider relations
OAuthApplications OAuthApplication[]
OAuthAuthorizationCodes OAuthAuthorizationCode[]
@@ -815,6 +817,45 @@ model CreditRefundRequest {
@@index([userId, transactionKey])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////// Platform Cost Tracking TABLES //////////////
////////////////////////////////////////////////////////////
model PlatformCostLog {
id String @id @default(uuid())
createdAt DateTime @default(now())
userId String?
User User? @relation(fields: [userId], references: [id], onDelete: SetNull)
graphExecId String?
nodeExecId String?
graphId String?
nodeId String?
blockId String?
blockName String?
provider String
credentialId String?
// Cost in microdollars (1 USD = 1,000,000). Null if unknown.
costMicrodollars BigInt?
inputTokens Int?
outputTokens Int?
dataSize Int? // bytes
duration Float? // seconds
model String?
trackingType String? // e.g. "cost_usd", "tokens", "characters", "items", "per_run", "sandbox_seconds", "walltime_seconds"
trackingAmount Float? // Amount in the unit implied by trackingType
metadata Json?
@@index([userId, createdAt])
@@index([provider, createdAt])
@@index([createdAt])
@@index([graphExecId])
@@index([provider, trackingType])
}
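
For reference, the microdollar convention in code, matching the round() behaviour pinned down in the cost-tracking tests above (helper names here are illustrative):

```python
MICRO = 1_000_000  # 1 USD = 1,000,000 microdollars


def usd_to_microdollars(usd: float) -> int:
    # round(), not int(): truncation can lose a microdollar to float error.
    return round(usd * MICRO)


def microdollars_to_usd(micro: int) -> float:
    return micro / MICRO


assert usd_to_microdollars(0.0015) == 1500
assert microdollars_to_usd(5_000_000) == 5.0
```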
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// Store TABLES ///////////////////////////

View File

@@ -66,6 +66,29 @@ describe("useOnboardingWizardStore", () => {
"no tests",
]);
});
it("ignores new selections when at the max limit", () => {
useOnboardingWizardStore.getState().togglePainPoint("a");
useOnboardingWizardStore.getState().togglePainPoint("b");
useOnboardingWizardStore.getState().togglePainPoint("c");
useOnboardingWizardStore.getState().togglePainPoint("d");
expect(useOnboardingWizardStore.getState().painPoints).toEqual([
"a",
"b",
"c",
]);
});
it("still allows deselecting when at the max limit", () => {
useOnboardingWizardStore.getState().togglePainPoint("a");
useOnboardingWizardStore.getState().togglePainPoint("b");
useOnboardingWizardStore.getState().togglePainPoint("c");
useOnboardingWizardStore.getState().togglePainPoint("b");
expect(useOnboardingWizardStore.getState().painPoints).toEqual([
"a",
"c",
]);
});
});
describe("setOtherPainPoint", () => {

View File

@@ -7,9 +7,9 @@ export function ProgressBar({ currentStep, totalSteps }: Props) {
const percent = (currentStep / totalSteps) * 100;
return (
<div className="absolute left-0 top-0 h-[0.625rem] w-full bg-neutral-300">
<div className="absolute left-0 top-0 h-[3px] w-full bg-neutral-200">
<div
className="h-full bg-purple-400 shadow-[0_0_4px_2px_rgba(168,85,247,0.5)] transition-all duration-500 ease-out"
className="h-full bg-purple-400 transition-all duration-500 ease-out"
style={{ width: `${percent}%` }}
/>
</div>

View File

@@ -2,6 +2,7 @@
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { Check } from "@phosphor-icons/react";
interface Props {
icon: React.ReactNode;
@@ -24,13 +25,18 @@ export function SelectableCard({
onClick={onClick}
aria-pressed={selected}
className={cn(
"flex h-[9rem] w-[10.375rem] shrink-0 flex-col items-center justify-center gap-3 rounded-xl border-2 bg-white px-6 py-5 transition-all hover:shadow-sm md:shrink lg:gap-2 lg:px-10 lg:py-8",
"relative flex h-[9rem] w-[10.375rem] shrink-0 flex-col items-center justify-center gap-3 rounded-xl border-2 bg-white px-6 py-5 transition-all hover:shadow-sm md:shrink lg:gap-2 lg:px-10 lg:py-8",
className,
selected
? "border-purple-500 bg-purple-50 shadow-sm"
: "border-transparent",
)}
>
{selected && (
<span className="absolute right-2 top-2 flex h-5 w-5 items-center justify-center rounded-full bg-purple-500">
<Check size={12} weight="bold" className="text-white" />
</span>
)}
<Text
variant="lead"
as="span"

View File

@@ -3,6 +3,7 @@
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { ReactNode } from "react";
import { FadeIn } from "@/components/atoms/FadeIn/FadeIn";
@@ -73,6 +74,8 @@ export function PainPointsStep() {
togglePainPoint,
setOtherPainPoint,
hasSomethingElse,
atLimit,
shaking,
canContinue,
handleLaunch,
} = usePainPointsStep();
@@ -90,7 +93,7 @@ export function PainPointsStep() {
What&apos;s eating your time?
</Text>
<Text variant="lead" className="!text-zinc-500">
Pick the tasks you&apos;d love to hand off to Autopilot
Pick the tasks you&apos;d love to hand off to AutoPilot
</Text>
</div>
@@ -107,11 +110,22 @@ export function PainPointsStep() {
/>
))}
</div>
{!hasSomethingElse ? (
<Text variant="small" className="!text-zinc-500">
Pick as many as you want; you can always change later
</Text>
) : null}
<Text
variant="small"
className={cn(
"transition-colors",
atLimit && canContinue ? "!text-green-600" : "!text-zinc-500",
shaking && "animate-shake",
)}
>
{shaking
? "You've picked 3 — tap one to swap it out"
: atLimit && canContinue
? "3 selected — you're all set!"
: atLimit && hasSomethingElse
? "Tell us what else takes up your time"
: "Pick up to 3 to start — AutoPilot can help with anything else later"}
</Text>
</div>
{hasSomethingElse && (
@@ -133,7 +147,7 @@ export function PainPointsStep() {
disabled={!canContinue}
className="w-full max-w-xs"
>
Launch Autopilot
Launch AutoPilot
</Button>
</div>
</FadeIn>

View File

@@ -8,6 +8,7 @@ import { FadeIn } from "@/components/atoms/FadeIn/FadeIn";
import { SelectableCard } from "../components/SelectableCard";
import { useOnboardingWizardStore } from "../store";
import { Emoji } from "@/components/atoms/Emoji/Emoji";
import { useEffect, useRef } from "react";
const IMG_SIZE = 42;
@@ -57,12 +58,26 @@ export function RoleStep() {
const setRole = useOnboardingWizardStore((s) => s.setRole);
const setOtherRole = useOnboardingWizardStore((s) => s.setOtherRole);
const nextStep = useOnboardingWizardStore((s) => s.nextStep);
const autoAdvanceTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
const isOther = role === "Other";
const canContinue = role && (!isOther || otherRole.trim());
function handleContinue() {
if (canContinue) {
useEffect(() => {
return () => {
if (autoAdvanceTimer.current) clearTimeout(autoAdvanceTimer.current);
};
}, []);
function handleRoleSelect(id: string) {
if (autoAdvanceTimer.current) clearTimeout(autoAdvanceTimer.current);
setRole(id);
if (id !== "Other") {
autoAdvanceTimer.current = setTimeout(nextStep, 350);
}
}
function handleOtherContinue() {
if (otherRole.trim()) {
nextStep();
}
}
@@ -78,7 +93,7 @@ export function RoleStep() {
What best describes you, {name}?
</Text>
<Text variant="lead" className="!text-zinc-500">
Autopilot will tailor automations to your world
So AutoPilot knows how to help you best
</Text>
</div>
@@ -89,33 +104,35 @@ export function RoleStep() {
icon={r.icon}
label={r.label}
selected={role === r.id}
onClick={() => setRole(r.id)}
onClick={() => handleRoleSelect(r.id)}
className="p-8"
/>
))}
</div>
{isOther && (
<div className="-mb-5 w-full px-8 md:px-0">
<Input
id="other-role"
label="Other role"
hideLabel
placeholder="Describe your role..."
value={otherRole}
onChange={(e) => setOtherRole(e.target.value)}
autoFocus
/>
</div>
)}
<>
<div className="-mb-5 w-full px-8 md:px-0">
<Input
id="other-role"
label="Other role"
hideLabel
placeholder="Describe your role..."
value={otherRole}
onChange={(e) => setOtherRole(e.target.value)}
autoFocus
/>
</div>
<Button
onClick={handleContinue}
disabled={!canContinue}
className="w-full max-w-xs"
>
Continue
</Button>
<Button
onClick={handleOtherContinue}
disabled={!otherRole.trim()}
className="w-full max-w-xs"
>
Continue
</Button>
</>
)}
</div>
</FadeIn>
);

View File

@@ -4,13 +4,6 @@ import { AutoGPTLogo } from "@/components/atoms/AutoGPTLogo/AutoGPTLogo";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import { Text } from "@/components/atoms/Text/Text";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { Question } from "@phosphor-icons/react";
import { FadeIn } from "@/components/atoms/FadeIn/FadeIn";
import { useOnboardingWizardStore } from "../store";
@@ -40,36 +33,16 @@ export function WelcomeStep() {
<Text variant="h3">Welcome to AutoGPT</Text>
<Text variant="lead" as="span" className="!text-zinc-500">
Let&apos;s personalize your experience so{" "}
<span className="relative mr-3 inline-block bg-gradient-to-r from-purple-500 to-indigo-500 bg-clip-text text-transparent">
Autopilot
<span className="absolute -right-4 top-0">
<TooltipProvider delayDuration={400}>
<Tooltip>
<TooltipTrigger asChild>
<button
type="button"
aria-label="What is Autopilot?"
className="inline-flex text-purple-500"
>
<Question size={14} />
</button>
</TooltipTrigger>
<TooltipContent>
Autopilot is AutoGPT&apos;s AI assistant that watches your
connected apps, spots repetitive tasks you do every day
and runs them for you automatically.
</TooltipContent>
</Tooltip>
</TooltipProvider>
</span>
<span className="bg-gradient-to-r from-purple-500 to-indigo-500 bg-clip-text text-transparent">
AutoPilot
</span>{" "}
can start saving you time right away
can start saving you time
</Text>
</div>
<Input
id="first-name"
label="Your first name"
label="What should I call you?"
placeholder="e.g. John"
value={name}
onChange={(e) => setName(e.target.value)}

View File

@@ -0,0 +1,154 @@
import {
render,
screen,
fireEvent,
cleanup,
} from "@/tests/integrations/test-utils";
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { useOnboardingWizardStore } from "../../store";
import { PainPointsStep } from "../PainPointsStep";
vi.mock("@/components/atoms/Emoji/Emoji", () => ({
Emoji: ({ text }: { text: string }) => <span>{text}</span>,
}));
vi.mock("@/components/atoms/FadeIn/FadeIn", () => ({
FadeIn: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
function getCard(name: RegExp) {
return screen.getByRole("button", { name });
}
function clickCard(name: RegExp) {
fireEvent.click(getCard(name));
}
function getLaunchButton() {
return screen.getByRole("button", { name: /launch autopilot/i });
}
afterEach(cleanup);
beforeEach(() => {
useOnboardingWizardStore.getState().reset();
useOnboardingWizardStore.getState().setName("Alice");
useOnboardingWizardStore.getState().setRole("Founder/CEO");
useOnboardingWizardStore.getState().goToStep(3);
});
describe("PainPointsStep", () => {
test("renders all pain point cards", () => {
render(<PainPointsStep />);
expect(getCard(/finding leads/i)).toBeDefined();
expect(getCard(/email & outreach/i)).toBeDefined();
expect(getCard(/reports & data/i)).toBeDefined();
expect(getCard(/customer support/i)).toBeDefined();
expect(getCard(/social media/i)).toBeDefined();
expect(getCard(/something else/i)).toBeDefined();
});
test("shows default helper text", () => {
render(<PainPointsStep />);
expect(
screen.getAllByText(/pick up to 3 to start/i).length,
).toBeGreaterThan(0);
});
test("selecting a card marks it as pressed", () => {
render(<PainPointsStep />);
clickCard(/finding leads/i);
expect(getCard(/finding leads/i).getAttribute("aria-pressed")).toBe("true");
});
test("launch button is disabled when nothing is selected", () => {
render(<PainPointsStep />);
expect(getLaunchButton().hasAttribute("disabled")).toBe(true);
});
test("launch button is enabled after selecting a pain point", () => {
render(<PainPointsStep />);
clickCard(/finding leads/i);
expect(getLaunchButton().hasAttribute("disabled")).toBe(false);
});
test("shows success text when 3 items are selected", () => {
render(<PainPointsStep />);
clickCard(/finding leads/i);
clickCard(/email & outreach/i);
clickCard(/reports & data/i);
expect(screen.getAllByText(/3 selected/i).length).toBeGreaterThan(0);
});
test("does not select a 4th item when at the limit", () => {
render(<PainPointsStep />);
clickCard(/finding leads/i);
clickCard(/email & outreach/i);
clickCard(/reports & data/i);
clickCard(/customer support/i);
expect(getCard(/customer support/i).getAttribute("aria-pressed")).toBe(
"false",
);
});
test("can deselect when at the limit and select a different one", () => {
render(<PainPointsStep />);
clickCard(/finding leads/i);
clickCard(/email & outreach/i);
clickCard(/reports & data/i);
clickCard(/finding leads/i);
expect(getCard(/finding leads/i).getAttribute("aria-pressed")).toBe(
"false",
);
clickCard(/customer support/i);
expect(getCard(/customer support/i).getAttribute("aria-pressed")).toBe(
"true",
);
});
test("shows input when 'Something else' is selected", () => {
render(<PainPointsStep />);
clickCard(/something else/i);
expect(
screen.getByPlaceholderText(/what else takes up your time/i),
).toBeDefined();
});
test("launch button is disabled when 'Something else' selected but input empty", () => {
render(<PainPointsStep />);
clickCard(/something else/i);
expect(getLaunchButton().hasAttribute("disabled")).toBe(true);
});
test("launch button is enabled when 'Something else' selected and input filled", () => {
render(<PainPointsStep />);
clickCard(/something else/i);
fireEvent.change(
screen.getByPlaceholderText(/what else takes up your time/i),
{ target: { value: "Manual invoicing" } },
);
expect(getLaunchButton().hasAttribute("disabled")).toBe(false);
});
});

View File

@@ -0,0 +1,123 @@
import {
render,
screen,
fireEvent,
cleanup,
} from "@/tests/integrations/test-utils";
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { useOnboardingWizardStore } from "../../store";
import { RoleStep } from "../RoleStep";
vi.mock("@/components/atoms/Emoji/Emoji", () => ({
Emoji: ({ text }: { text: string }) => <span>{text}</span>,
}));
vi.mock("@/components/atoms/FadeIn/FadeIn", () => ({
FadeIn: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
afterEach(() => {
cleanup();
vi.useRealTimers();
});
beforeEach(() => {
vi.useFakeTimers();
useOnboardingWizardStore.getState().reset();
useOnboardingWizardStore.getState().setName("Alice");
useOnboardingWizardStore.getState().goToStep(2);
});
describe("RoleStep", () => {
test("renders all role cards", () => {
render(<RoleStep />);
expect(screen.getByText("Founder / CEO")).toBeDefined();
expect(screen.getByText("Operations")).toBeDefined();
expect(screen.getByText("Sales / BD")).toBeDefined();
expect(screen.getByText("Marketing")).toBeDefined();
expect(screen.getByText("Product / PM")).toBeDefined();
expect(screen.getByText("Engineering")).toBeDefined();
expect(screen.getByText("HR / People")).toBeDefined();
expect(screen.getByText("Other")).toBeDefined();
});
test("displays the user name in the heading", () => {
render(<RoleStep />);
expect(
screen.getAllByText(/what best describes you, alice/i).length,
).toBeGreaterThan(0);
});
test("selecting a non-Other role auto-advances after delay", () => {
render(<RoleStep />);
fireEvent.click(screen.getByRole("button", { name: /engineering/i }));
expect(useOnboardingWizardStore.getState().role).toBe("Engineering");
expect(useOnboardingWizardStore.getState().currentStep).toBe(2);
vi.advanceTimersByTime(350);
expect(useOnboardingWizardStore.getState().currentStep).toBe(3);
});
test("selecting 'Other' does not auto-advance", () => {
render(<RoleStep />);
fireEvent.click(screen.getByRole("button", { name: /\bother\b/i }));
vi.advanceTimersByTime(500);
expect(useOnboardingWizardStore.getState().currentStep).toBe(2);
});
test("selecting 'Other' shows text input and Continue button", () => {
render(<RoleStep />);
fireEvent.click(screen.getByRole("button", { name: /\bother\b/i }));
expect(screen.getByPlaceholderText(/describe your role/i)).toBeDefined();
expect(screen.getByRole("button", { name: /continue/i })).toBeDefined();
});
test("Continue button is disabled when Other input is empty", () => {
render(<RoleStep />);
fireEvent.click(screen.getByRole("button", { name: /\bother\b/i }));
const continueBtn = screen.getByRole("button", { name: /continue/i });
expect(continueBtn.hasAttribute("disabled")).toBe(true);
});
test("Continue button advances when Other role text is filled", () => {
render(<RoleStep />);
fireEvent.click(screen.getByRole("button", { name: /\bother\b/i }));
fireEvent.change(screen.getByPlaceholderText(/describe your role/i), {
target: { value: "Designer" },
});
const continueBtn = screen.getByRole("button", { name: /continue/i });
expect(continueBtn.hasAttribute("disabled")).toBe(false);
fireEvent.click(continueBtn);
expect(useOnboardingWizardStore.getState().currentStep).toBe(3);
});
test("switching from Other to a regular role cancels Other and auto-advances", () => {
render(<RoleStep />);
fireEvent.click(screen.getByRole("button", { name: /\bother\b/i }));
expect(screen.getByPlaceholderText(/describe your role/i)).toBeDefined();
fireEvent.click(screen.getByRole("button", { name: /marketing/i }));
expect(useOnboardingWizardStore.getState().role).toBe("Marketing");
vi.advanceTimersByTime(350);
expect(useOnboardingWizardStore.getState().currentStep).toBe(3);
});
});

View File

@@ -1,4 +1,5 @@
import { useOnboardingWizardStore } from "../store";
import { useEffect, useRef, useState } from "react";
import { MAX_PAIN_POINT_SELECTIONS, useOnboardingWizardStore } from "../store";
const ROLE_TOP_PICKS: Record<string, string[]> = {
"Founder/CEO": [
@@ -23,18 +24,38 @@ export function usePainPointsStep() {
const role = useOnboardingWizardStore((s) => s.role);
const painPoints = useOnboardingWizardStore((s) => s.painPoints);
const otherPainPoint = useOnboardingWizardStore((s) => s.otherPainPoint);
const togglePainPoint = useOnboardingWizardStore((s) => s.togglePainPoint);
const storeToggle = useOnboardingWizardStore((s) => s.togglePainPoint);
const setOtherPainPoint = useOnboardingWizardStore(
(s) => s.setOtherPainPoint,
);
const nextStep = useOnboardingWizardStore((s) => s.nextStep);
const [shaking, setShaking] = useState(false);
const shakeTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
useEffect(() => {
return () => {
if (shakeTimer.current) clearTimeout(shakeTimer.current);
};
}, []);
const topIDs = getTopPickIDs(role);
const hasSomethingElse = painPoints.includes("Something else");
const atLimit = painPoints.length >= MAX_PAIN_POINT_SELECTIONS;
const canContinue =
painPoints.length > 0 &&
(!hasSomethingElse || Boolean(otherPainPoint.trim()));
function togglePainPoint(id: string) {
const alreadySelected = painPoints.includes(id);
if (!alreadySelected && atLimit) {
if (shakeTimer.current) clearTimeout(shakeTimer.current);
setShaking(true);
shakeTimer.current = setTimeout(() => setShaking(false), 600);
return;
}
storeToggle(id);
}
function handleLaunch() {
if (canContinue) {
nextStep();
@@ -48,6 +69,8 @@ export function usePainPointsStep() {
togglePainPoint,
setOtherPainPoint,
hasSomethingElse,
atLimit,
shaking,
canContinue,
handleLaunch,
};

View File

@@ -1,5 +1,6 @@
import { create } from "zustand";
export const MAX_PAIN_POINT_SELECTIONS = 3;
export type Step = 1 | 2 | 3 | 4;
interface OnboardingWizardState {
@@ -40,6 +41,8 @@ export const useOnboardingWizardStore = create<OnboardingWizardState>(
togglePainPoint(painPoint) {
set((state) => {
const exists = state.painPoints.includes(painPoint);
if (!exists && state.painPoints.length >= MAX_PAIN_POINT_SELECTIONS)
return state;
return {
painPoints: exists
? state.painPoints.filter((p) => p !== painPoint)

View File

@@ -1,6 +1,12 @@
import { Sidebar } from "@/components/__legacy__/Sidebar";
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
import { Gauge } from "@phosphor-icons/react/dist/ssr";
import {
Users,
CurrencyDollar,
MagnifyingGlass,
Gauge,
Receipt,
FileText,
} from "@phosphor-icons/react/dist/ssr";
import { IconSliders } from "@/components/__legacy__/ui/icons";
@@ -15,18 +21,23 @@ const sidebarLinkGroups = [
{
text: "User Spending",
href: "/admin/spending",
icon: <DollarSign className="h-6 w-6" />,
icon: <CurrencyDollar className="h-6 w-6" />,
},
{
text: "User Impersonation",
href: "/admin/impersonation",
icon: <UserSearch className="h-6 w-6" />,
icon: <MagnifyingGlass className="h-6 w-6" />,
},
{
text: "Rate Limits",
href: "/admin/rate-limits",
icon: <Gauge className="h-6 w-6" />,
},
{
text: "Platform Costs",
href: "/admin/platform-costs",
icon: <Receipt className="h-6 w-6" />,
},
{
text: "Execution Analytics",
href: "/admin/execution-analytics",

View File

@@ -0,0 +1,429 @@
import {
render,
screen,
cleanup,
waitFor,
} from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { PlatformCostContent } from "../components/PlatformCostContent";
import type { PlatformCostDashboard } from "@/app/api/__generated__/models/platformCostDashboard";
import type { PlatformCostLogsResponse } from "@/app/api/__generated__/models/platformCostLogsResponse";
// Mock the generated Orval hooks so tests don't hit the network
const mockUseGetDashboard = vi.fn();
const mockUseGetLogs = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({
useGetV2GetPlatformCostDashboard: (...args: unknown[]) =>
mockUseGetDashboard(...args),
useGetV2GetPlatformCostLogs: (...args: unknown[]) => mockUseGetLogs(...args),
}));
afterEach(() => {
cleanup();
mockUseGetDashboard.mockReset();
mockUseGetLogs.mockReset();
});
const emptyDashboard: PlatformCostDashboard = {
total_cost_microdollars: 0,
total_requests: 0,
total_users: 0,
by_provider: [],
by_user: [],
};
const emptyLogs: PlatformCostLogsResponse = {
logs: [],
pagination: {
current_page: 1,
page_size: 50,
total_items: 0,
total_pages: 0,
},
};
const dashboardWithData: PlatformCostDashboard = {
total_cost_microdollars: 5_000_000,
total_requests: 100,
total_users: 5,
by_provider: [
{
provider: "openai",
tracking_type: "tokens",
total_cost_microdollars: 3_000_000,
total_input_tokens: 50000,
total_output_tokens: 20000,
total_duration_seconds: 0,
request_count: 60,
},
{
provider: "google_maps",
tracking_type: "per_run",
total_cost_microdollars: 0,
total_input_tokens: 0,
total_output_tokens: 0,
total_duration_seconds: 0,
request_count: 40,
},
],
by_user: [
{
user_id: "user-1",
email: "alice@example.com",
total_cost_microdollars: 3_000_000,
total_input_tokens: 50000,
total_output_tokens: 20000,
request_count: 60,
},
],
};
const logsWithData: PlatformCostLogsResponse = {
logs: [
{
id: "log-1",
created_at: "2026-03-01T00:00:00Z" as unknown as Date,
user_id: "user-1",
email: "alice@example.com",
graph_exec_id: "gx-123",
node_exec_id: "nx-456",
block_name: "LLMBlock",
provider: "openai",
tracking_type: "tokens",
cost_microdollars: 5000,
input_tokens: 100,
output_tokens: 50,
duration: 1.5,
model: "gpt-4",
},
],
pagination: {
current_page: 1,
page_size: 50,
total_items: 1,
total_pages: 1,
},
};
function renderComponent(searchParams = {}) {
return render(<PlatformCostContent searchParams={searchParams} />);
}
describe("PlatformCostContent", () => {
it("shows loading state initially", () => {
mockUseGetDashboard.mockReturnValue({ data: undefined, isLoading: true });
mockUseGetLogs.mockReturnValue({ data: undefined, isLoading: true });
renderComponent();
// Loading state renders Skeleton placeholders (animate-pulse divs) instead of content
expect(screen.queryByText("Loading...")).toBeNull();
// Summary cards and table content are not yet shown
expect(screen.queryByText("Known Cost")).toBeNull();
});
it("renders empty dashboard", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: emptyLogs,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Verify the two summary cards that show $0.0000 — Known Cost and Estimated Total
const zeroCostItems = screen.getAllByText("$0.0000");
expect(zeroCostItems.length).toBe(2);
expect(screen.getByText("No cost data yet")).toBeDefined();
});
it("renders dashboard with provider data", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("$5.0000")).toBeDefined();
expect(screen.getByText("100")).toBeDefined();
expect(screen.getByText("5")).toBeDefined();
expect(screen.getByText("openai")).toBeDefined();
expect(screen.getByText("google_maps")).toBeDefined();
});
it("renders tracking type badges", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("tokens")).toBeDefined();
expect(screen.getByText("per_run")).toBeDefined();
});
it("shows error state on fetch failure", async () => {
mockUseGetDashboard.mockReturnValue({
data: undefined,
isLoading: false,
error: new Error("Network error"),
});
mockUseGetLogs.mockReturnValue({
data: undefined,
isLoading: false,
error: new Error("Network error"),
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Network error")).toBeDefined();
});
it("renders tab buttons", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("By Provider")).toBeDefined();
expect(screen.getByText("By User")).toBeDefined();
expect(screen.getByText("Raw Logs")).toBeDefined();
});
it("renders summary cards with correct labels", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("Estimated Total")).toBeDefined();
expect(screen.getByText("Total Requests")).toBeDefined();
expect(screen.getByText("Active Users")).toBeDefined();
});
it("renders filter inputs", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Start Date")).toBeDefined();
expect(screen.getByText("End Date")).toBeDefined();
expect(screen.getAllByText(/Provider/i).length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("User ID")).toBeDefined();
expect(screen.getByText("Apply")).toBeDefined();
});
it("renders by-user tab when specified", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "by-user" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("alice@example.com")).toBeDefined();
});
it("renders logs tab when specified", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("LLMBlock")).toBeDefined();
expect(screen.getByText("gpt-4")).toBeDefined();
});
it("renders no logs message when empty", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("No logs found")).toBeDefined();
});
it("shows pagination when multiple pages", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
const multiPageLogs: PlatformCostLogsResponse = {
logs: logsWithData.logs,
pagination: {
current_page: 1,
page_size: 50,
total_items: 200,
total_pages: 4,
},
};
mockUseGetLogs.mockReturnValue({
data: multiPageLogs,
isLoading: false,
});
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Previous")).toBeDefined();
expect(screen.getByText("Next")).toBeDefined();
expect(screen.getByText(/Page 1 of 4/)).toBeDefined();
});
it("renders user table with unknown email", async () => {
const dashWithNullEmail: PlatformCostDashboard = {
...dashboardWithData,
by_user: [
{
user_id: "user-2",
email: null,
total_cost_microdollars: 1000,
total_input_tokens: 100,
total_output_tokens: 50,
request_count: 5,
},
],
};
mockUseGetDashboard.mockReturnValue({
data: dashWithNullEmail,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent({ tab: "by-user" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Unknown")).toBeDefined();
});
it("by-user tab content visible when tab=by-user param set", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "by-user" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("alice@example.com")).toBeDefined();
// overview tab content should not be visible
expect(screen.queryByText("openai")).toBeNull();
});
it("logs tab content visible when tab=logs param set", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("LLMBlock")).toBeDefined();
expect(screen.getByText("gpt-4")).toBeDefined();
});
it("renders log with null user as dash", async () => {
const logWithNullUser: PlatformCostLogsResponse = {
logs: [
{
id: "log-2",
created_at: "2026-03-01T00:00:00Z" as unknown as Date,
user_id: null,
email: null,
graph_exec_id: null,
node_exec_id: null,
block_name: "copilot:SDK",
provider: "anthropic",
tracking_type: "cost_usd",
cost_microdollars: 15000,
input_tokens: null,
output_tokens: null,
duration: null,
model: "claude-opus-4-20250514",
},
],
pagination: {
current_page: 1,
page_size: 50,
total_items: 1,
total_pages: 1,
},
};
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logWithNullUser,
isLoading: false,
});
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("copilot:SDK")).toBeDefined();
expect(screen.getByText("anthropic")).toBeDefined();
// null email + null user_id renders as "-" in the User column; multiple
// other cells (tokens, duration, session) also render "-", so use
// getAllByText to avoid the single-match constraint.
expect(screen.getAllByText("-").length).toBeGreaterThan(0);
});
});

View File

@@ -0,0 +1,87 @@
import { describe, expect, it, vi } from "vitest";
const mockGetDashboard = vi.fn();
const mockGetLogs = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({
getV2GetPlatformCostDashboard: (...args: unknown[]) =>
mockGetDashboard(...args),
getV2GetPlatformCostLogs: (...args: unknown[]) => mockGetLogs(...args),
}));
import { getPlatformCostDashboard, getPlatformCostLogs } from "../actions";
describe("getPlatformCostDashboard", () => {
it("returns data on success", async () => {
const mockData = { total_cost_microdollars: 1000, total_requests: 5 };
mockGetDashboard.mockResolvedValue({ status: 200, data: mockData });
const result = await getPlatformCostDashboard();
expect(result).toEqual(mockData);
});
it("returns undefined on non-200", async () => {
mockGetDashboard.mockResolvedValue({ status: 401 });
const result = await getPlatformCostDashboard();
expect(result).toBeUndefined();
});
it("passes filter params to API", async () => {
mockGetDashboard.mockReset();
mockGetDashboard.mockResolvedValue({ status: 200, data: {} });
await getPlatformCostDashboard({
start: "2026-01-01T00:00:00",
end: "2026-06-01T00:00:00",
provider: "openai",
user_id: "user-1",
});
expect(mockGetDashboard).toHaveBeenCalledTimes(1);
const params = mockGetDashboard.mock.calls[0][0];
expect(params.start).toBe("2026-01-01T00:00:00");
expect(params.end).toBe("2026-06-01T00:00:00");
expect(params.provider).toBe("openai");
expect(params.user_id).toBe("user-1");
});
it("passes undefined for empty filter strings", async () => {
mockGetDashboard.mockReset();
mockGetDashboard.mockResolvedValue({ status: 200, data: {} });
await getPlatformCostDashboard({
start: "",
provider: "",
user_id: "",
});
expect(mockGetDashboard).toHaveBeenCalledTimes(1);
const params = mockGetDashboard.mock.calls[0][0];
expect(params.start).toBeUndefined();
expect(params.provider).toBeUndefined();
expect(params.user_id).toBeUndefined();
});
});
describe("getPlatformCostLogs", () => {
it("returns data on success", async () => {
const mockData = { logs: [], pagination: { current_page: 1 } };
mockGetLogs.mockResolvedValue({ status: 200, data: mockData });
const result = await getPlatformCostLogs();
expect(result).toEqual(mockData);
});
it("passes page and page_size", async () => {
mockGetLogs.mockReset();
mockGetLogs.mockResolvedValue({ status: 200, data: { logs: [] } });
await getPlatformCostLogs({ page: 3, page_size: 25 });
expect(mockGetLogs).toHaveBeenCalledTimes(1);
const params = mockGetLogs.mock.calls[0][0];
expect(params.page).toBe(3);
expect(params.page_size).toBe(25);
});
it("passes start date string through to API", async () => {
mockGetLogs.mockReset();
mockGetLogs.mockResolvedValue({ status: 200, data: { logs: [] } });
await getPlatformCostLogs({ start: "2026-03-01T00:00:00" });
expect(mockGetLogs).toHaveBeenCalledTimes(1);
const params = mockGetLogs.mock.calls[0][0];
expect(params.start).toBe("2026-03-01T00:00:00");
});
});

View File

@@ -0,0 +1,300 @@
import { describe, expect, it } from "vitest";
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
import {
toDateOrUndefined,
formatMicrodollars,
formatTokens,
formatDuration,
estimateCostForRow,
trackingValue,
toLocalInput,
toUtcIso,
} from "../helpers";
function makeRow(overrides: Partial<ProviderCostSummary>): ProviderCostSummary {
return {
provider: "openai",
tracking_type: null,
total_cost_microdollars: 0,
total_input_tokens: 0,
total_output_tokens: 0,
total_duration_seconds: 0,
request_count: 0,
...overrides,
};
}
describe("toDateOrUndefined", () => {
it("returns undefined for empty string", () => {
expect(toDateOrUndefined("")).toBeUndefined();
});
it("returns undefined for undefined", () => {
expect(toDateOrUndefined(undefined)).toBeUndefined();
});
it("returns undefined for invalid date string", () => {
expect(toDateOrUndefined("not-a-date")).toBeUndefined();
});
it("returns a Date for a valid ISO string", () => {
const result = toDateOrUndefined("2026-01-15T00:00:00Z");
expect(result).toBeInstanceOf(Date);
expect(result!.toISOString()).toBe("2026-01-15T00:00:00.000Z");
});
});
describe("formatMicrodollars", () => {
it("formats zero", () => {
expect(formatMicrodollars(0)).toBe("$0.0000");
});
it("formats a small amount", () => {
expect(formatMicrodollars(50_000)).toBe("$0.0500");
});
it("formats one dollar", () => {
expect(formatMicrodollars(1_000_000)).toBe("$1.0000");
});
});
describe("formatTokens", () => {
it("formats small numbers as-is", () => {
expect(formatTokens(500)).toBe("500");
});
it("formats thousands with K suffix", () => {
expect(formatTokens(1_500)).toBe("1.5K");
});
it("formats millions with M suffix", () => {
expect(formatTokens(2_500_000)).toBe("2.5M");
});
});
describe("formatDuration", () => {
it("formats seconds", () => {
expect(formatDuration(30)).toBe("30.0s");
});
it("formats minutes", () => {
expect(formatDuration(90)).toBe("1.5m");
});
it("formats hours", () => {
expect(formatDuration(5400)).toBe("1.5h");
});
});
describe("estimateCostForRow", () => {
it("returns microdollars directly for cost_usd tracking", () => {
const row = makeRow({
tracking_type: "cost_usd",
total_cost_microdollars: 500_000,
});
expect(estimateCostForRow(row, {})).toBe(500_000);
});
it("returns reported cost for token tracking when cost > 0", () => {
const row = makeRow({
tracking_type: "tokens",
total_cost_microdollars: 100_000,
total_input_tokens: 1000,
total_output_tokens: 500,
});
expect(estimateCostForRow(row, {})).toBe(100_000);
});
it("estimates cost from default rate for token tracking with zero cost", () => {
const row = makeRow({
provider: "openai",
tracking_type: "tokens",
total_cost_microdollars: 0,
total_input_tokens: 500,
total_output_tokens: 500,
});
// 1000 tokens / 1000 * 0.005 USD * 1_000_000 = 5000
expect(estimateCostForRow(row, {})).toBe(5000);
});
it("returns null for unknown token provider with zero cost", () => {
const row = makeRow({
provider: "unknown_provider",
tracking_type: "tokens",
total_cost_microdollars: 0,
});
expect(estimateCostForRow(row, {})).toBeNull();
});
it("uses per-run override when provided", () => {
const row = makeRow({
provider: "google_maps",
tracking_type: "per_run",
request_count: 10,
});
// override = 0.05 * 10 * 1_000_000 = 500_000
expect(estimateCostForRow(row, { "google_maps:per_run": 0.05 })).toBe(
500_000,
);
});
it("uses default per-run cost when no override", () => {
const row = makeRow({
provider: "google_maps",
tracking_type: null,
request_count: 5,
});
// 0.032 * 5 * 1_000_000 = 160_000
expect(estimateCostForRow(row, {})).toBe(160_000);
});
it("returns null for unknown per_run provider", () => {
const row = makeRow({
provider: "totally_unknown",
tracking_type: "per_run",
request_count: 3,
});
expect(estimateCostForRow(row, {})).toBeNull();
});
it("returns null for duration tracking with no rate and no cost", () => {
const row = makeRow({
provider: "openai",
tracking_type: "duration_seconds",
total_cost_microdollars: 0,
total_duration_seconds: 100,
});
expect(estimateCostForRow(row, {})).toBeNull();
});
it("estimates cost from default rate for characters tracking", () => {
const row = makeRow({
provider: "elevenlabs",
tracking_type: "characters",
total_cost_microdollars: 0,
total_tracking_amount: 2000,
});
// 2000 chars / 1000 * 0.18 USD * 1_000_000 = 360_000
expect(estimateCostForRow(row, {})).toBe(360_000);
});
it("estimates cost from default rate for items tracking", () => {
const row = makeRow({
provider: "apollo",
tracking_type: "items",
total_cost_microdollars: 0,
total_tracking_amount: 50,
});
// 50 * 0.02 * 1_000_000 = 1_000_000
expect(estimateCostForRow(row, {})).toBe(1_000_000);
});
it("estimates cost from default rate for duration tracking", () => {
const row = makeRow({
provider: "e2b",
tracking_type: "sandbox_seconds",
total_cost_microdollars: 0,
total_duration_seconds: 1_000_000,
});
// 1_000_000 * 0.000014 * 1_000_000 = 14_000_000
expect(estimateCostForRow(row, {})).toBe(14_000_000);
});
});
describe("trackingValue", () => {
it("returns formatted microdollars for cost_usd", () => {
const row = makeRow({
tracking_type: "cost_usd",
total_cost_microdollars: 1_000_000,
});
expect(trackingValue(row)).toBe("$1.0000");
});
it("returns formatted token count for tokens", () => {
const row = makeRow({
tracking_type: "tokens",
total_input_tokens: 500,
total_output_tokens: 500,
});
expect(trackingValue(row)).toBe("1.0K tokens");
});
it("returns formatted duration for sandbox_seconds", () => {
const row = makeRow({
tracking_type: "sandbox_seconds",
total_duration_seconds: 120,
});
expect(trackingValue(row)).toBe("2.0m");
});
it("returns run count for per_run (default tracking)", () => {
const row = makeRow({
tracking_type: null,
request_count: 42,
});
expect(trackingValue(row)).toBe("42 runs");
});
it("returns formatted character count for characters tracking", () => {
const row = makeRow({
tracking_type: "characters",
total_tracking_amount: 2500,
});
expect(trackingValue(row)).toBe("2.5K chars");
});
it("returns formatted item count for items tracking", () => {
const row = makeRow({
tracking_type: "items",
total_tracking_amount: 1234,
});
expect(trackingValue(row)).toBe("1,234 items");
});
it("returns formatted duration for sandbox_seconds", () => {
const row = makeRow({
tracking_type: "sandbox_seconds",
total_duration_seconds: 7200,
});
expect(trackingValue(row)).toBe("2.0h");
});
it("returns formatted duration for walltime_seconds", () => {
const row = makeRow({
tracking_type: "walltime_seconds",
total_duration_seconds: 45,
});
expect(trackingValue(row)).toBe("45.0s");
});
});
describe("toLocalInput", () => {
it("returns empty string for empty input", () => {
expect(toLocalInput("")).toBe("");
});
it("returns empty string for invalid ISO", () => {
expect(toLocalInput("not-a-date")).toBe("");
});
it("converts UTC ISO to local datetime-local format", () => {
const result = toLocalInput("2026-01-15T12:30:00Z");
// Format should be YYYY-MM-DDTHH:mm
expect(result).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$/);
});
});
describe("toUtcIso", () => {
it("returns empty string for empty input", () => {
expect(toUtcIso("")).toBe("");
});
it("returns empty string for invalid local time", () => {
expect(toUtcIso("not-a-date")).toBe("");
});
it("converts local datetime-local to ISO string", () => {
const result = toUtcIso("2026-01-15T12:30");
expect(result).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/);
});
});

View File

@@ -0,0 +1,45 @@
import {
getV2GetPlatformCostDashboard,
getV2GetPlatformCostLogs,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { okData } from "@/app/api/helpers";
// Backend expects ISO datetime strings. The generated client's URL builder
// calls .toString() on values, which for Date objects produces the human
// "Tue Mar 31 2026 22:00:00 GMT+0000 (Coordinated Universal Time)" format
// that FastAPI rejects with 422. We already pass UTC ISO from the URL, so
// forward the raw strings through the `as unknown as Date` cast to match
// the generated typing without triggering Date.toString().
export async function getPlatformCostDashboard(params?: {
start?: string;
end?: string;
provider?: string;
user_id?: string;
}) {
const response = await getV2GetPlatformCostDashboard({
start: (params?.start || undefined) as unknown as Date | undefined,
end: (params?.end || undefined) as unknown as Date | undefined,
provider: params?.provider || undefined,
user_id: params?.user_id || undefined,
});
return okData(response);
}
export async function getPlatformCostLogs(params?: {
start?: string;
end?: string;
provider?: string;
user_id?: string;
page?: number;
page_size?: number;
}) {
const response = await getV2GetPlatformCostLogs({
start: (params?.start || undefined) as unknown as Date | undefined,
end: (params?.end || undefined) as unknown as Date | undefined,
provider: params?.provider || undefined,
user_id: params?.user_id || undefined,
page: params?.page,
page_size: params?.page_size,
});
return okData(response);
}
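// Usage sketch (hypothetical values): the ISO string survives the cast and is
// forwarded verbatim, so FastAPI parses it as a datetime instead of rejecting
// the request with 422:
//   await getPlatformCostLogs({ start: "2026-03-01T00:00:00.000Z", page: 1 });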

View File

@@ -0,0 +1,140 @@
import type { CostLogRow } from "@/app/api/__generated__/models/costLogRow";
import type { Pagination } from "@/app/api/__generated__/models/pagination";
import { formatDuration, formatMicrodollars, formatTokens } from "../helpers";
import { TrackingBadge } from "./TrackingBadge";
function formatLogDate(value: unknown): string {
if (value instanceof Date) return value.toLocaleString();
if (typeof value === "string" || typeof value === "number")
return new Date(value).toLocaleString();
return "-";
}
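// Example (illustrative): formatLogDate("2026-03-01T00:00:00Z") returns the
// viewer's locale string for that instant; any other value type renders "-".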
interface Props {
logs: CostLogRow[];
pagination: Pagination | null;
onPageChange: (page: number) => void;
}
function LogsTable({ logs, pagination, onPageChange }: Props) {
return (
<div className="flex flex-col gap-4">
<div className="overflow-x-auto">
<table className="w-full text-left text-sm">
<thead className="border-b text-xs uppercase text-muted-foreground">
<tr>
<th scope="col" className="px-3 py-3">
Time
</th>
<th scope="col" className="px-3 py-3">
User
</th>
<th scope="col" className="px-3 py-3">
Block
</th>
<th scope="col" className="px-3 py-3">
Provider
</th>
<th scope="col" className="px-3 py-3">
Type
</th>
<th scope="col" className="px-3 py-3">
Model
</th>
<th scope="col" className="px-3 py-3 text-right">
Cost
</th>
<th scope="col" className="px-3 py-3 text-right">
Tokens
</th>
<th scope="col" className="px-3 py-3 text-right">
Duration
</th>
<th scope="col" className="px-3 py-3">
Execution
</th>
</tr>
</thead>
<tbody>
{logs.map((log) => (
<tr key={log.id} className="border-b hover:bg-muted">
<td className="whitespace-nowrap px-3 py-2 text-xs">
{formatLogDate(log.created_at)}
</td>
<td className="px-3 py-2 text-xs">
{log.email ||
(log.user_id ? String(log.user_id).slice(0, 8) : "-")}
</td>
<td className="px-3 py-2 text-xs font-medium">
{log.block_name}
</td>
<td className="px-3 py-2 text-xs">{log.provider}</td>
<td className="px-3 py-2 text-xs">
<TrackingBadge trackingType={log.tracking_type} />
</td>
<td className="px-3 py-2 text-xs">{log.model || "-"}</td>
<td className="px-3 py-2 text-right text-xs">
{log.cost_microdollars != null
? formatMicrodollars(Number(log.cost_microdollars))
: "-"}
</td>
<td className="px-3 py-2 text-right text-xs">
{log.input_tokens != null || log.output_tokens != null
? `${formatTokens(Number(log.input_tokens ?? 0))} / ${formatTokens(Number(log.output_tokens ?? 0))}`
: "-"}
</td>
<td className="px-3 py-2 text-right text-xs">
{log.duration != null
? formatDuration(Number(log.duration))
: "-"}
</td>
<td className="px-3 py-2 text-xs text-muted-foreground">
{log.graph_exec_id
? String(log.graph_exec_id).slice(0, 8)
: "-"}
</td>
</tr>
))}
{logs.length === 0 && (
<tr>
<td
colSpan={10}
className="px-4 py-8 text-center text-muted-foreground"
>
No logs found
</td>
</tr>
)}
</tbody>
</table>
</div>
{pagination && pagination.total_pages > 1 && (
<div className="flex items-center justify-between px-4">
<span className="text-sm text-muted-foreground">
Page {pagination.current_page} of {pagination.total_pages} (
{pagination.total_items} total)
</span>
<div className="flex gap-2">
<button
disabled={pagination.current_page <= 1}
onClick={() => onPageChange(pagination.current_page - 1)}
className="rounded border px-3 py-1 text-sm disabled:opacity-50"
>
Previous
</button>
<button
disabled={pagination.current_page >= pagination.total_pages}
onClick={() => onPageChange(pagination.current_page + 1)}
className="rounded border px-3 py-1 text-sm disabled:opacity-50"
>
Next
</button>
</div>
</div>
)}
</div>
);
}
export { LogsTable };

View File

@@ -0,0 +1,234 @@
"use client";
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { formatMicrodollars } from "../helpers";
import { SummaryCard } from "./SummaryCard";
import { ProviderTable } from "./ProviderTable";
import { UserTable } from "./UserTable";
import { LogsTable } from "./LogsTable";
import { usePlatformCostContent } from "./usePlatformCostContent";
interface Props {
searchParams: {
start?: string;
end?: string;
provider?: string;
user_id?: string;
page?: string;
tab?: string;
};
}
export function PlatformCostContent({ searchParams }: Props) {
const {
dashboard,
logs,
pagination,
loading,
error,
totalEstimatedCost,
tab,
startInput,
setStartInput,
endInput,
setEndInput,
providerInput,
setProviderInput,
userInput,
setUserInput,
rateOverrides,
handleRateOverride,
updateUrl,
handleFilter,
} = usePlatformCostContent(searchParams);
return (
<div className="flex flex-col gap-6">
<div className="flex flex-wrap items-end gap-3 rounded-lg border p-4">
<div className="flex flex-col gap-1">
<label htmlFor="start-date" className="text-sm text-muted-foreground">
Start Date{" "}
<span className="text-xs">
(local time; defaults to last 30 days)
</span>
</label>
<input
id="start-date"
type="datetime-local"
className="rounded border px-3 py-1.5 text-sm"
value={startInput}
onChange={(e) => setStartInput(e.target.value)}
/>
</div>
<div className="flex flex-col gap-1">
<label htmlFor="end-date" className="text-sm text-muted-foreground">
End Date <span className="text-xs">(local time)</span>
</label>
<input
id="end-date"
type="datetime-local"
className="rounded border px-3 py-1.5 text-sm"
value={endInput}
onChange={(e) => setEndInput(e.target.value)}
/>
</div>
<div className="flex flex-col gap-1">
<label
htmlFor="provider-filter"
className="text-sm text-muted-foreground"
>
Provider
</label>
<input
id="provider-filter"
type="text"
placeholder="e.g. openai"
className="rounded border px-3 py-1.5 text-sm"
value={providerInput}
onChange={(e) => setProviderInput(e.target.value)}
/>
</div>
<div className="flex flex-col gap-1">
<label
htmlFor="user-id-filter"
className="text-sm text-muted-foreground"
>
User ID
</label>
<input
id="user-id-filter"
type="text"
placeholder="Filter by user"
className="rounded border px-3 py-1.5 text-sm"
value={userInput}
onChange={(e) => setUserInput(e.target.value)}
/>
</div>
<button
onClick={handleFilter}
className="rounded bg-primary px-4 py-1.5 text-sm text-primary-foreground hover:bg-primary/90"
>
Apply
</button>
<button
onClick={() => {
setStartInput("");
setEndInput("");
setProviderInput("");
setUserInput("");
updateUrl({
start: "",
end: "",
provider: "",
user_id: "",
page: "1",
});
}}
className="rounded border px-4 py-1.5 text-sm hover:bg-muted"
>
Clear
</button>
</div>
{error && (
<Alert variant="error">
<AlertDescription>{error}</AlertDescription>
</Alert>
)}
{loading ? (
<div className="flex flex-col gap-4">
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
{[...Array(4)].map((_, i) => (
<Skeleton key={i} className="h-20 rounded-lg" />
))}
</div>
<Skeleton className="h-8 w-48 rounded" />
<Skeleton className="h-64 rounded-lg" />
</div>
) : (
<>
{dashboard && (
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
<SummaryCard
label="Known Cost"
value={formatMicrodollars(dashboard.total_cost_microdollars)}
subtitle="From providers that report USD cost"
/>
<SummaryCard
label="Estimated Total"
value={formatMicrodollars(totalEstimatedCost)}
subtitle="Including per-run cost estimates"
/>
<SummaryCard
label="Total Requests"
value={dashboard.total_requests.toLocaleString()}
/>
<SummaryCard
label="Active Users"
value={dashboard.total_users.toLocaleString()}
/>
</div>
)}
<div
role="tablist"
aria-label="Cost view tabs"
className="flex gap-2 border-b"
>
{["overview", "by-user", "logs"].map((t) => (
<button
key={t}
id={`tab-${t}`}
role="tab"
aria-selected={tab === t}
aria-controls={`tabpanel-${t}`}
onClick={() => updateUrl({ tab: t, page: "1" })}
className={`px-4 py-2 text-sm font-medium ${tab === t ? "border-b-2 border-primary text-primary" : "text-muted-foreground hover:text-foreground"}`}
>
{t === "overview"
? "By Provider"
: t === "by-user"
? "By User"
: "Raw Logs"}
</button>
))}
</div>
{tab === "overview" && dashboard && (
<div
role="tabpanel"
id="tabpanel-overview"
aria-labelledby="tab-overview"
>
<ProviderTable
data={dashboard.by_provider}
rateOverrides={rateOverrides}
onRateOverride={handleRateOverride}
/>
</div>
)}
{tab === "by-user" && dashboard && (
<div
role="tabpanel"
id="tabpanel-by-user"
aria-labelledby="tab-by-user"
>
<UserTable data={dashboard.by_user} />
</div>
)}
{tab === "logs" && (
<div role="tabpanel" id="tabpanel-logs" aria-labelledby="tab-logs">
<LogsTable
logs={logs}
pagination={pagination}
onPageChange={(p) => updateUrl({ page: p.toString() })}
/>
</div>
)}
</>
)}
</div>
);
}

View File

@@ -0,0 +1,131 @@
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
import {
defaultRateFor,
estimateCostForRow,
formatMicrodollars,
rateKey,
rateUnitLabel,
trackingValue,
} from "../helpers";
import { TrackingBadge } from "./TrackingBadge";
interface Props {
data: ProviderCostSummary[];
rateOverrides: Record<string, number>;
onRateOverride: (key: string, val: number | null) => void;
}
function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
return (
<div className="overflow-x-auto">
<table className="w-full text-left text-sm">
<thead className="border-b text-xs uppercase text-muted-foreground">
<tr>
<th scope="col" className="px-4 py-3">
Provider
</th>
<th scope="col" className="px-4 py-3">
Type
</th>
<th scope="col" className="px-4 py-3 text-right">
Usage
</th>
<th scope="col" className="px-4 py-3 text-right">
Requests
</th>
<th scope="col" className="px-4 py-3 text-right">
Known Cost
</th>
<th scope="col" className="px-4 py-3 text-right">
Est. Cost
</th>
<th
scope="col"
className="px-4 py-3 text-right"
title="Per-session only"
>
Rate <span className="text-[10px] font-normal">(unsaved)</span>
</th>
</tr>
</thead>
<tbody>
{data.map((row) => {
const est = estimateCostForRow(row, rateOverrides);
const tt = row.tracking_type || "per_run";
// For cost_usd rows the provider reports USD directly so rate
// input doesn't apply; otherwise show an editable input.
const showRateInput = tt !== "cost_usd";
const key = rateKey(row.provider, tt);
const fallback = defaultRateFor(row.provider, tt);
const currentRate = rateOverrides[key] ?? fallback;
return (
<tr key={key} className="border-b hover:bg-muted">
<td className="px-4 py-3 font-medium">{row.provider}</td>
<td className="px-4 py-3">
<TrackingBadge trackingType={row.tracking_type} />
</td>
<td className="px-4 py-3 text-right">{trackingValue(row)}</td>
<td className="px-4 py-3 text-right">
{row.request_count.toLocaleString()}
</td>
<td className="px-4 py-3 text-right">
{row.total_cost_microdollars > 0
? formatMicrodollars(row.total_cost_microdollars)
: "-"}
</td>
<td className="px-4 py-3 text-right">
{est !== null ? (
formatMicrodollars(est)
) : (
<span className="text-muted-foreground">-</span>
)}
</td>
<td className="px-4 py-2 text-right">
{showRateInput ? (
<div className="flex items-center justify-end gap-1">
<input
type="number"
step="0.0001"
min="0"
aria-label={`Rate for ${row.provider} (${tt})`}
className="w-24 rounded border px-2 py-1 text-right text-xs"
placeholder={fallback !== null ? String(fallback) : "0"}
value={currentRate ?? ""}
onChange={(e) => {
const val = parseFloat(e.target.value);
if (!isNaN(val)) onRateOverride(key, val);
else if (e.target.value === "")
onRateOverride(key, null);
}}
/>
<span
className="text-[10px] text-muted-foreground"
title={rateUnitLabel(tt)}
>
{rateUnitLabel(tt)}
</span>
</div>
) : (
<span className="text-xs text-muted-foreground">auto</span>
)}
</td>
</tr>
);
})}
{data.length === 0 && (
<tr>
<td
colSpan={7}
className="px-4 py-8 text-center text-muted-foreground"
>
No cost data yet
</td>
</tr>
)}
</tbody>
</table>
</div>
);
}
export { ProviderTable };

View File

@@ -0,0 +1,19 @@
interface Props {
label: string;
value: string;
subtitle?: string;
}
function SummaryCard({ label, value, subtitle }: Props) {
return (
<div className="rounded-lg border p-4">
<div className="text-sm text-muted-foreground">{label}</div>
<div className="text-2xl font-bold">{value}</div>
{subtitle && (
<div className="mt-1 text-xs text-muted-foreground">{subtitle}</div>
)}
</div>
);
}
export { SummaryCard };

View File

@@ -0,0 +1,25 @@
function TrackingBadge({
trackingType,
}: {
trackingType: string | null | undefined;
}) {
const colors: Record<string, string> = {
cost_usd: "bg-green-500/10 text-green-700",
tokens: "bg-blue-500/10 text-blue-700",
characters: "bg-purple-500/10 text-purple-700",
sandbox_seconds: "bg-orange-500/10 text-orange-700",
walltime_seconds: "bg-orange-500/10 text-orange-700",
items: "bg-pink-500/10 text-pink-700",
per_run: "bg-muted text-muted-foreground",
};
const label = trackingType || "per_run";
return (
<span
className={`inline-block rounded px-1.5 py-0.5 text-[10px] font-medium ${colors[label] || colors.per_run}`}
>
{label}
</span>
);
}
export { TrackingBadge };

View File

@@ -0,0 +1,75 @@
import type { PlatformCostDashboard } from "@/app/api/__generated__/models/platformCostDashboard";
import { formatMicrodollars, formatTokens } from "../helpers";
interface Props {
data: PlatformCostDashboard["by_user"];
}
function UserTable({ data }: Props) {
return (
<div className="overflow-x-auto">
<table className="w-full text-left text-sm">
<thead className="border-b text-xs uppercase text-muted-foreground">
<tr>
<th scope="col" className="px-4 py-3">
User
</th>
<th scope="col" className="px-4 py-3 text-right">
Known Cost
</th>
<th scope="col" className="px-4 py-3 text-right">
Requests
</th>
<th scope="col" className="px-4 py-3 text-right">
Input Tokens
</th>
<th scope="col" className="px-4 py-3 text-right">
Output Tokens
</th>
</tr>
</thead>
<tbody>
{data.map((row, idx) => (
<tr
key={row.user_id ?? `unknown-${idx}`}
className="border-b hover:bg-muted"
>
<td className="px-4 py-3">
<div className="font-medium">{row.email || "Unknown"}</div>
<div className="text-xs text-muted-foreground">
{row.user_id}
</div>
</td>
<td className="px-4 py-3 text-right">
{row.total_cost_microdollars > 0
? formatMicrodollars(row.total_cost_microdollars)
: "-"}
</td>
<td className="px-4 py-3 text-right">
{row.request_count.toLocaleString()}
</td>
<td className="px-4 py-3 text-right">
{formatTokens(row.total_input_tokens)}
</td>
<td className="px-4 py-3 text-right">
{formatTokens(row.total_output_tokens)}
</td>
</tr>
))}
{data.length === 0 && (
<tr>
<td
colSpan={5}
className="px-4 py-8 text-center text-muted-foreground"
>
No cost data yet
</td>
</tr>
)}
</tbody>
</table>
</div>
);
}
export { UserTable };

View File

@@ -0,0 +1,136 @@
"use client";
import { useRouter, useSearchParams } from "next/navigation";
import { useState } from "react";
import {
useGetV2GetPlatformCostDashboard,
useGetV2GetPlatformCostLogs,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { okData } from "@/app/api/helpers";
import { estimateCostForRow, toLocalInput, toUtcIso } from "../helpers";
interface InitialSearchParams {
start?: string;
end?: string;
provider?: string;
user_id?: string;
page?: string;
tab?: string;
}
export function usePlatformCostContent(searchParams: InitialSearchParams) {
const router = useRouter();
const urlParams = useSearchParams();
const tab = urlParams.get("tab") || searchParams.tab || "overview";
const page = parseInt(urlParams.get("page") || searchParams.page || "1", 10);
const startDate = urlParams.get("start") || searchParams.start || "";
const endDate = urlParams.get("end") || searchParams.end || "";
const providerFilter =
urlParams.get("provider") || searchParams.provider || "";
const userFilter = urlParams.get("user_id") || searchParams.user_id || "";
const [startInput, setStartInput] = useState(toLocalInput(startDate));
const [endInput, setEndInput] = useState(toLocalInput(endDate));
const [providerInput, setProviderInput] = useState(providerFilter);
const [userInput, setUserInput] = useState(userFilter);
const [rateOverrides, setRateOverrides] = useState<Record<string, number>>(
{},
);
// Pass ISO date strings through `as unknown as Date` so Orval's URL builder
// forwards them as-is. Date.toString() produces a format FastAPI rejects;
// strings pass through .toString() unchanged.
const filterParams = {
start: (startDate || undefined) as unknown as Date | undefined,
end: (endDate || undefined) as unknown as Date | undefined,
provider: providerFilter || undefined,
user_id: userFilter || undefined,
};
const {
data: dashboard,
isLoading: dashLoading,
error: dashError,
} = useGetV2GetPlatformCostDashboard(filterParams, {
query: { select: okData },
});
const {
data: logsResponse,
isLoading: logsLoading,
error: logsError,
} = useGetV2GetPlatformCostLogs(
{ ...filterParams, page, page_size: 50 },
{ query: { select: okData } },
);
const loading = dashLoading || logsLoading;
const error = dashError
? dashError instanceof Error
? dashError.message
: "Failed to load dashboard"
: logsError
? logsError instanceof Error
? logsError.message
: "Failed to load logs"
: null;
function updateUrl(overrides: Record<string, string>) {
const params = new URLSearchParams(urlParams.toString());
for (const [k, v] of Object.entries(overrides)) {
if (v) params.set(k, v);
else params.delete(k);
}
router.push(`/admin/platform-costs?${params.toString()}`);
}
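// Example (illustrative, assuming no other params are set):
// updateUrl({ tab: "logs", page: "1" }) pushes
// /admin/platform-costs?tab=logs&page=1; passing an empty string for a key
// deletes it from the URL instead.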
function handleFilter() {
updateUrl({
start: toUtcIso(startInput),
end: toUtcIso(endInput),
provider: providerInput,
user_id: userInput,
page: "1",
});
}
function handleRateOverride(key: string, val: number | null) {
setRateOverrides((prev) => {
if (val === null) {
const { [key]: _, ...rest } = prev;
return rest;
}
return { ...prev, [key]: val };
});
}
const totalEstimatedCost =
dashboard?.by_provider.reduce((sum, row) => {
const est = estimateCostForRow(row, rateOverrides);
return sum + (est ?? 0);
}, 0) ?? 0;
return {
dashboard: dashboard ?? null,
logs: logsResponse?.logs ?? [],
pagination: logsResponse?.pagination ?? null,
loading,
error,
totalEstimatedCost,
tab,
page,
startInput,
setStartInput,
endInput,
setEndInput,
providerInput,
setProviderInput,
userInput,
setUserInput,
rateOverrides,
handleRateOverride,
updateUrl,
handleFilter,
};
}

View File

@@ -0,0 +1,204 @@
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
const MICRODOLLARS_PER_USD = 1_000_000;
// Per-request cost estimates (USD) for providers billed per API call.
export const DEFAULT_COST_PER_RUN: Record<string, number> = {
google_maps: 0.032, // $0.032/request - Google Maps Places API
ideogram: 0.08, // $0.08/image - Ideogram standard generation
nvidia: 0.0, // Free tier - NVIDIA NIM deepfake detection
screenshotone: 0.01, // ~$0.01/screenshot - ScreenshotOne starter
zerobounce: 0.008, // $0.008/validation - ZeroBounce
mem0: 0.01, // ~$0.01/request - Mem0
openweathermap: 0.0, // Free tier
webshare_proxy: 0.0, // Flat subscription
enrichlayer: 0.1, // ~$0.10/profile lookup
jina: 0.0, // Free tier
};
export const DEFAULT_COST_PER_1K_TOKENS: Record<string, number> = {
openai: 0.005,
anthropic: 0.008,
groq: 0.0003,
ollama: 0.0,
aiml_api: 0.005,
llama_api: 0.003,
v0: 0.005,
};
// Per-character rates (USD / 1K characters) for TTS providers.
export const DEFAULT_COST_PER_1K_CHARS: Record<string, number> = {
unreal_speech: 0.008, // ~$8/1M chars on Starter
elevenlabs: 0.18, // ~$0.18/1K chars on Starter
d_id: 0.04, // ~$0.04/1K chars estimated
};
// Per-item rates (USD / item) for item-count billed APIs.
export const DEFAULT_COST_PER_ITEM: Record<string, number> = {
google_maps: 0.017, // blended estimate of $0.032 nearby search + ~$0.015 detail enrich
apollo: 0.02, // ~$0.02/contact on low-volume tiers
smartlead: 0.001, // ~$0.001/lead added
};
// Per-second rates (USD / second) for duration-billed providers.
export const DEFAULT_COST_PER_SECOND: Record<string, number> = {
e2b: 0.000014, // $0.000014/sec (2-core sandbox)
fal: 0.0005, // varies by model, conservative
replicate: 0.001, // varies by hardware
revid: 0.01, // per-second of video
};
export function toDateOrUndefined(val?: string): Date | undefined {
if (!val) return undefined;
const d = new Date(val);
return isNaN(d.getTime()) ? undefined : d;
}
export function formatMicrodollars(microdollars: number) {
return `$${(microdollars / MICRODOLLARS_PER_USD).toFixed(4)}`;
}
export function formatTokens(tokens: number) {
if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`;
if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(1)}K`;
return tokens.toString();
}
export function formatDuration(seconds: number) {
if (seconds >= 3600) return `${(seconds / 3600).toFixed(1)}h`;
if (seconds >= 60) return `${(seconds / 60).toFixed(1)}m`;
return `${seconds.toFixed(1)}s`;
}
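// Examples (mirroring the helper unit tests): formatMicrodollars(50_000) ===
// "$0.0500", formatTokens(2_500_000) === "2.5M", formatDuration(5400) === "1.5h".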
// Unit label for each tracking type — what the rate input represents.
export function rateUnitLabel(trackingType: string | null | undefined): string {
switch (trackingType) {
case "tokens":
return "$/1K tokens";
case "characters":
return "$/1K chars";
case "items":
return "$/item";
case "sandbox_seconds":
case "walltime_seconds":
return "$/second";
case "per_run":
return "$/run";
default:
return "";
}
}
// Default rate for a (provider, tracking_type) pair.
export function defaultRateFor(
provider: string,
trackingType: string | null | undefined,
): number | null {
switch (trackingType) {
case "tokens":
return DEFAULT_COST_PER_1K_TOKENS[provider] ?? null;
case "characters":
return DEFAULT_COST_PER_1K_CHARS[provider] ?? null;
case "items":
return DEFAULT_COST_PER_ITEM[provider] ?? null;
case "sandbox_seconds":
case "walltime_seconds":
return DEFAULT_COST_PER_SECOND[provider] ?? null;
case "per_run":
return DEFAULT_COST_PER_RUN[provider] ?? null;
default:
return null;
}
}
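// Example (illustrative): defaultRateFor("elevenlabs", "characters") === 0.18
// ($/1K chars); an unlisted provider returns null, so no estimate is rendered.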
// Overrides are keyed on `${provider}:${tracking_type}` since the same
// provider can have multiple rows with different billing models.
export function rateKey(
provider: string,
trackingType: string | null | undefined,
): string {
return `${provider}:${trackingType ?? "per_run"}`;
}
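// Example (illustrative): rateKey("openai", "tokens") === "openai:tokens" and
// rateKey("google_maps", null) === "google_maps:per_run".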
export function estimateCostForRow(
row: ProviderCostSummary,
rateOverrides: Record<string, number>,
) {
const tt = row.tracking_type || "per_run";
// Providers that report USD directly: use known cost.
if (tt === "cost_usd") return row.total_cost_microdollars;
// Prefer the real USD the provider reported if any, but only for token paths
// where OpenRouter piggybacks on the tokens row via x-total-cost.
if (tt === "tokens" && row.total_cost_microdollars > 0) {
return row.total_cost_microdollars;
}
const rate =
rateOverrides[rateKey(row.provider, tt)] ??
defaultRateFor(row.provider, tt);
if (rate === null || rate === undefined) return null;
// Compute the amount for this tracking type, then multiply by rate.
let amount: number;
switch (tt) {
case "tokens":
// Rate is per-1K tokens.
amount = (row.total_input_tokens + row.total_output_tokens) / 1000;
break;
case "characters":
// Rate is per-1K chars; total_tracking_amount aggregates character counts.
amount = (row.total_tracking_amount || 0) / 1000;
break;
case "items":
amount = row.total_tracking_amount || 0;
break;
case "sandbox_seconds":
case "walltime_seconds":
amount = row.total_duration_seconds || 0;
break;
case "per_run":
amount = row.request_count;
break;
default:
return row.total_cost_microdollars > 0
? row.total_cost_microdollars
: null;
}
return Math.round(rate * amount * MICRODOLLARS_PER_USD);
}
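// Worked example (mirroring the unit tests): an openai "tokens" row with
// 500 input + 500 output tokens and zero reported cost falls back to the
// default $0.005/1K-token rate, so the estimate is
//   (1000 / 1000) * 0.005 * 1_000_000 = 5_000 microdollars.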
export function trackingValue(row: ProviderCostSummary) {
const tt = row.tracking_type || "per_run";
if (tt === "cost_usd") return formatMicrodollars(row.total_cost_microdollars);
if (tt === "tokens") {
const tokens = row.total_input_tokens + row.total_output_tokens;
return `${formatTokens(tokens)} tokens`;
}
if (tt === "sandbox_seconds" || tt === "walltime_seconds")
return formatDuration(row.total_duration_seconds || 0);
if (tt === "characters")
return `${formatTokens(Math.round(row.total_tracking_amount || 0))} chars`;
if (tt === "items")
return `${Math.round(row.total_tracking_amount || 0).toLocaleString()} items`;
return `${row.request_count.toLocaleString()} runs`;
}
// URL holds UTC ISO; datetime-local inputs need local "YYYY-MM-DDTHH:mm".
export function toLocalInput(iso: string) {
if (!iso) return "";
const d = new Date(iso);
if (isNaN(d.getTime())) return "";
const pad = (n: number) => String(n).padStart(2, "0");
return `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())}T${pad(d.getHours())}:${pad(d.getMinutes())}`;
}
// datetime-local emits naive local time; convert to UTC ISO so the
// backend filter window matches what the admin sees in their browser.
export function toUtcIso(local: string) {
if (!local) return "";
const d = new Date(local);
return isNaN(d.getTime()) ? "" : d.toISOString();
}
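// Round-trip sketch (assuming a UTC+2 browser): toUtcIso("2026-01-15T12:30")
// === "2026-01-15T10:30:00.000Z", and toLocalInput of that ISO renders back
// as "2026-01-15T12:30" in the datetime-local input.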

View File

@@ -0,0 +1,50 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { Suspense } from "react";
import { PlatformCostContent } from "./components/PlatformCostContent";
type SearchParams = {
start?: string;
end?: string;
provider?: string;
user_id?: string;
page?: string;
tab?: string;
};
function PlatformCostDashboard({
searchParams,
}: {
searchParams: SearchParams;
}) {
return (
<div className="mx-auto p-6">
<div className="flex flex-col gap-4">
<div>
<h1 className="text-3xl font-bold">Platform Costs</h1>
<p className="text-muted-foreground">
Track real API costs incurred by system credentials across providers
</p>
</div>
<Suspense
key={JSON.stringify(searchParams)}
fallback={
<div className="py-10 text-center">Loading cost data...</div>
}
>
<PlatformCostContent searchParams={searchParams} />
</Suspense>
</div>
</div>
);
}
export default async function PlatformCostDashboardPage({
searchParams,
}: {
searchParams: Promise<SearchParams>;
}) {
const withAdminAccess = await withRoleAccess(["admin"]);
const ProtectedDashboard = await withAdminAccess(PlatformCostDashboard);
return <ProtectedDashboard searchParams={await searchParams} />;
}

View File

@@ -40,14 +40,14 @@ export const ContentRenderer: React.FC<{
!shortContent
) {
return (
<div className="overflow-hidden [&>*]:rounded-xlarge [&>*]:!text-xs [&_pre]:whitespace-pre-wrap [&_pre]:break-words">
<div className="overflow-x-auto [&>*]:rounded-xlarge [&>*]:!text-xs [&_pre]:whitespace-pre-wrap [&_pre]:break-words">
{renderer?.render(value, metadata)}
</div>
);
}
return (
<div className="overflow-hidden [&>*]:rounded-xlarge [&>*]:!text-xs">
<div className="overflow-x-auto [&>*]:rounded-xlarge [&>*]:!text-xs">
<TextRenderer value={value} truncateLengthLimit={200} />
</div>
);

View File

@@ -8,6 +8,7 @@ import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { SidebarProvider } from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { UploadSimple } from "@phosphor-icons/react";
import dynamic from "next/dynamic";
import { useCallback, useEffect, useRef, useState } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
@@ -20,6 +21,14 @@ import { RateLimitResetDialog } from "./components/RateLimitResetDialog/RateLimi
import { ScaleLoader } from "./components/ScaleLoader/ScaleLoader";
import { useCopilotPage } from "./useCopilotPage";
const ArtifactPanel = dynamic(
() =>
import("./components/ArtifactPanel/ArtifactPanel").then(
(m) => m.ArtifactPanel,
),
{ ssr: false },
);
export function CopilotPage() {
const [isDragging, setIsDragging] = useState(false);
const [droppedFiles, setDroppedFiles] = useState<File[]>([]);
@@ -80,6 +89,10 @@ export function CopilotPage() {
isUploadingFiles,
isUserLoading,
isLoggedIn,
// Pagination
hasMoreMessages,
isLoadingMore,
loadMore,
// Mobile drawer
isMobile,
isDrawerOpen,
@@ -116,6 +129,7 @@ export function CopilotPage() {
const resetCost = usage?.reset_cost;
const isBillingEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
const isArtifactsEnabled = useGetFlag(Flag.ARTIFACTS);
const { credits, fetchCredits } = useCredits({ fetchInitialCredits: true });
const hasInsufficientCredits =
credits !== null && resetCost != null && credits < resetCost;
@@ -150,48 +164,55 @@ export function CopilotPage() {
className="h-[calc(100vh-72px)] min-h-0"
>
{!isMobile && <ChatSidebar />}
<div
className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0"
onDragEnter={handleDragEnter}
onDragOver={handleDragOver}
onDragLeave={handleDragLeave}
onDrop={handleDrop}
>
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<NotificationBanner />
{/* Drop overlay */}
<div className="flex h-full w-full flex-row overflow-hidden">
<div
className={cn(
"pointer-events-none absolute inset-0 z-50 flex flex-col items-center justify-center gap-3 rounded-lg border-2 border-dashed border-violet-400 bg-violet-500/10 transition-opacity duration-150",
isDragging ? "opacity-100" : "opacity-0",
)}
className="relative flex min-w-0 flex-1 flex-col overflow-hidden bg-[#f8f8f9] px-0"
onDragEnter={handleDragEnter}
onDragOver={handleDragOver}
onDragLeave={handleDragLeave}
onDrop={handleDrop}
>
<UploadSimple className="h-10 w-10 text-violet-500" weight="bold" />
<span className="text-lg font-medium text-violet-600">
Drop files here
</span>
</div>
<div className="flex-1 overflow-hidden">
<ChatContainer
messages={messages}
status={status}
error={error}
sessionId={sessionId}
isLoadingSession={isLoadingSession}
isSessionError={isSessionError}
isCreatingSession={isCreatingSession}
isReconnecting={isReconnecting}
isSyncing={isSyncing}
onCreateSession={createSession}
onSend={onSend}
onStop={stop}
isUploadingFiles={isUploadingFiles}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
historicalDurations={historicalDurations}
/>
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<NotificationBanner />
{/* Drop overlay */}
<div
className={cn(
"pointer-events-none absolute inset-0 z-50 flex flex-col items-center justify-center gap-3 rounded-lg border-2 border-dashed border-violet-400 bg-violet-500/10 transition-opacity duration-150",
isDragging ? "opacity-100" : "opacity-0",
)}
>
<UploadSimple className="h-10 w-10 text-violet-500" weight="bold" />
<span className="text-lg font-medium text-violet-600">
Drop files here
</span>
</div>
<div className="flex-1 overflow-hidden">
<ChatContainer
messages={messages}
status={status}
error={error}
sessionId={sessionId}
isLoadingSession={isLoadingSession}
isSessionError={isSessionError}
isCreatingSession={isCreatingSession}
isReconnecting={isReconnecting}
isSyncing={isSyncing}
onCreateSession={createSession}
onSend={onSend}
onStop={stop}
isUploadingFiles={isUploadingFiles}
hasMoreMessages={hasMoreMessages}
isLoadingMore={isLoadingMore}
onLoadMore={loadMore}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
historicalDurations={historicalDurations}
/>
</div>
</div>
{!isMobile && isArtifactsEnabled && <ArtifactPanel />}
</div>
{isMobile && isArtifactsEnabled && <ArtifactPanel mobile />}
{isMobile && (
<MobileDrawer
isOpen={isDrawerOpen}

View File

@@ -0,0 +1,114 @@
"use client";
import { toast } from "@/components/molecules/Toast/use-toast";
import { cn } from "@/lib/utils";
import { CaretRight, DownloadSimple } from "@phosphor-icons/react";
import type { ArtifactRef } from "../../store";
import { useCopilotUIStore } from "../../store";
import { downloadArtifact } from "../ArtifactPanel/downloadArtifact";
import { classifyArtifact } from "../ArtifactPanel/helpers";
interface Props {
artifact: ArtifactRef;
}
function formatSize(bytes?: number): string {
if (!bytes) return "";
if (bytes < 1024) return `${bytes} B`;
if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)} KB`;
return `${(bytes / (1024 * 1024)).toFixed(1)} MB`;
}
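// Example (illustrative): formatSize(1536) === "1.5 KB"; 0 or undefined
// returns "" so the size separator is omitted entirely.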
export function ArtifactCard({ artifact }: Props) {
const activeID = useCopilotUIStore((s) => s.artifactPanel.activeArtifact?.id);
const isOpen = useCopilotUIStore((s) => s.artifactPanel.isOpen);
const openArtifact = useCopilotUIStore((s) => s.openArtifact);
const isActive = isOpen && activeID === artifact.id;
const classification = classifyArtifact(
artifact.mimeType,
artifact.title,
artifact.sizeBytes,
);
const Icon = classification.icon;
function handleDownloadOnly() {
downloadArtifact(artifact).catch(() => {
toast({
title: "Download failed",
description: "Couldn't fetch the file.",
variant: "destructive",
});
});
}
if (!classification.openable) {
return (
<button
type="button"
onClick={handleDownloadOnly}
className="my-1 flex w-full items-center gap-3 rounded-lg border border-zinc-200 bg-white px-3 py-2.5 text-left transition-colors hover:bg-zinc-50"
>
<Icon size={20} className="shrink-0 text-zinc-400" />
<div className="min-w-0 flex-1">
<p className="truncate text-sm font-medium text-zinc-900">
{artifact.title}
</p>
<p className="text-xs text-zinc-400">
{classification.label}
{artifact.sizeBytes
? ` \u2022 ${formatSize(artifact.sizeBytes)}`
: ""}
</p>
</div>
<DownloadSimple size={16} className="shrink-0 text-zinc-400" />
</button>
);
}
return (
<button
type="button"
onClick={() => openArtifact(artifact)}
className={cn(
"my-1 flex w-full items-center gap-3 rounded-lg border bg-white px-3 py-2.5 text-left transition-colors hover:bg-zinc-50",
isActive ? "border-violet-300 bg-violet-50/50" : "border-zinc-200",
)}
>
<Icon
size={20}
className={cn(
"shrink-0",
isActive ? "text-violet-500" : "text-zinc-400",
)}
/>
<div className="min-w-0 flex-1">
<p className="truncate text-sm font-medium text-zinc-900">
{artifact.title}
</p>
<p className="text-xs text-zinc-400">
<span
className={cn(
"inline-block rounded-full px-1.5 py-0.5 text-xs font-medium",
artifact.origin === "user-upload"
? "bg-blue-50 text-blue-500"
: "bg-violet-50 text-violet-500",
)}
>
{classification.label}
</span>
{artifact.sizeBytes
? ` \u2022 ${formatSize(artifact.sizeBytes)}`
: ""}
</p>
</div>
<CaretRight
size={16}
className={cn(
"shrink-0",
isActive ? "text-violet-400" : "text-zinc-300",
)}
/>
</button>
);
}

View File

@@ -0,0 +1,125 @@
"use client";
import {
Sheet,
SheetContent,
SheetHeader,
SheetTitle,
} from "@/components/ui/sheet";
import { AnimatePresence, motion } from "framer-motion";
import { ArtifactContent } from "./components/ArtifactContent";
import { ArtifactDragHandle } from "./components/ArtifactDragHandle";
import { ArtifactMinimizedStrip } from "./components/ArtifactMinimizedStrip";
import { ArtifactPanelHeader } from "./components/ArtifactPanelHeader";
import { useArtifactPanel } from "./useArtifactPanel";
interface Props {
mobile?: boolean;
}
export function ArtifactPanel({ mobile }: Props) {
const {
isOpen,
isMinimized,
isMaximized,
activeArtifact,
history,
effectiveWidth,
isSourceView,
classification,
setIsSourceView,
closeArtifactPanel,
minimizeArtifactPanel,
maximizeArtifactPanel,
restoreArtifactPanel,
setArtifactPanelWidth,
goBackArtifact,
canCopy,
handleCopy,
handleDownload,
} = useArtifactPanel();
if (!activeArtifact || !classification) return null;
const headerProps = {
artifact: activeArtifact,
classification,
canGoBack: history.length > 0,
isMaximized,
isSourceView,
hasSourceToggle: classification.hasSourceToggle,
mobile: !!mobile,
canCopy,
onBack: goBackArtifact,
onClose: closeArtifactPanel,
onMinimize: minimizeArtifactPanel,
onMaximize: maximizeArtifactPanel,
onRestore: restoreArtifactPanel,
onCopy: handleCopy,
onDownload: handleDownload,
onSourceToggle: setIsSourceView,
};
// Mobile: fullscreen Sheet overlay
if (mobile) {
return (
<Sheet
open={isOpen}
onOpenChange={(open) => !open && closeArtifactPanel()}
>
<SheetContent
side="right"
className="flex w-full flex-col p-0 sm:max-w-full"
>
<SheetHeader className="sr-only">
<SheetTitle>{activeArtifact.title}</SheetTitle>
</SheetHeader>
<ArtifactPanelHeader {...headerProps} />
<ArtifactContent
artifact={activeArtifact}
isSourceView={isSourceView}
classification={classification}
/>
</SheetContent>
</Sheet>
);
}
// Minimized strip
if (isOpen && isMinimized) {
return (
<ArtifactMinimizedStrip
artifact={activeArtifact}
classification={classification}
onExpand={restoreArtifactPanel}
/>
);
}
// Keep AnimatePresence mounted across the open→closed transition so the
// exit animation on the motion.div has a chance to run.
return (
<AnimatePresence>
{isOpen && (
<motion.div
key="artifact-panel"
data-artifact-panel
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.25, ease: "easeInOut" }}
className="relative flex h-full flex-col overflow-hidden border-l border-zinc-200 bg-white"
style={{ width: effectiveWidth }}
>
<ArtifactDragHandle onWidthChange={setArtifactPanelWidth} />
<ArtifactPanelHeader {...headerProps} />
<ArtifactContent
artifact={activeArtifact}
isSourceView={isSourceView}
classification={classification}
/>
</motion.div>
)}
</AnimatePresence>
);
}

View File

@@ -0,0 +1,198 @@
"use client";
import { globalRegistry } from "@/components/contextual/OutputRenderers";
import { codeRenderer } from "@/components/contextual/OutputRenderers/renderers/CodeRenderer";
import { Suspense } from "react";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import { ArtifactReactPreview } from "./ArtifactReactPreview";
import { ArtifactSkeleton } from "./ArtifactSkeleton";
import {
TAILWIND_CDN_URL,
wrapWithHeadInjection,
} from "@/lib/iframe-sandbox-csp";
import { useArtifactContent } from "./useArtifactContent";
interface Props {
artifact: ArtifactRef;
isSourceView: boolean;
classification: ArtifactClassification;
}
function ArtifactContentLoader({
artifact,
isSourceView,
classification,
}: Props) {
const { content, pdfUrl, isLoading, error, scrollRef, retry } =
useArtifactContent(artifact, classification);
if (isLoading) {
return <ArtifactSkeleton extraLine />;
}
if (error) {
return (
<div
role="alert"
className="flex flex-col items-center justify-center gap-3 p-8 text-center"
>
<p className="text-sm text-zinc-500">Failed to load content</p>
<p className="text-xs text-zinc-400">{error}</p>
<button
type="button"
onClick={retry}
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
>
Try again
</button>
</div>
);
}
return (
<div ref={scrollRef} className="flex-1 overflow-y-auto">
<ArtifactRenderer
artifact={artifact}
content={content}
pdfUrl={pdfUrl}
isSourceView={isSourceView}
classification={classification}
/>
</div>
);
}
function ArtifactRenderer({
artifact,
content,
pdfUrl,
isSourceView,
classification,
}: {
artifact: ArtifactRef;
content: string | null;
pdfUrl: string | null;
isSourceView: boolean;
classification: ArtifactClassification;
}) {
// Image: render directly from URL (no content fetch)
if (classification.type === "image") {
return (
<div className="flex items-center justify-center p-4">
{/* eslint-disable-next-line @next/next/no-img-element */}
<img
src={artifact.sourceUrl}
alt={artifact.title}
className="max-h-full max-w-full object-contain"
/>
</div>
);
}
if (classification.type === "pdf" && pdfUrl) {
// No sandbox — Chrome/Edge block PDF rendering in sandboxed iframes
// (Chromium bug #413851). The blob URL has a null origin so it can't
// access the parent page regardless.
return (
<iframe src={pdfUrl} className="h-full w-full" title={artifact.title} />
);
}
if (content === null) return null;
// Source view: always show raw text
if (isSourceView) {
return (
<pre className="whitespace-pre-wrap break-words p-4 font-mono text-sm text-zinc-800">
{content}
</pre>
);
}
if (classification.type === "html") {
// Inject Tailwind CDN — no CSP (see iframe-sandbox-csp.ts for why)
const tailwindScript = `<script src="${TAILWIND_CDN_URL}"></script>`;
const wrapped = wrapWithHeadInjection(content, tailwindScript);
return (
<iframe
sandbox="allow-scripts"
srcDoc={wrapped}
className="h-full w-full border-0"
title={artifact.title}
/>
);
}
if (classification.type === "react") {
return <ArtifactReactPreview source={content} title={artifact.title} />;
}
// Code: pass with explicit type metadata so CodeRenderer matches
// (prevents higher-priority MarkdownRenderer from claiming it)
if (classification.type === "code") {
const ext = artifact.title.split(".").pop() ?? "";
const codeMeta = {
mimeType: artifact.mimeType ?? undefined,
filename: artifact.title,
type: "code",
language: ext,
};
return <div className="p-4">{codeRenderer.render(content, codeMeta)}</div>;
}
// JSON: parse first so the JSONRenderer gets an object, not a string
// (prevents higher-priority MarkdownRenderer from claiming it)
if (classification.type === "json") {
try {
const parsed = JSON.parse(content);
const jsonMeta = {
mimeType: "application/json",
type: "json",
filename: artifact.title,
};
const jsonRenderer = globalRegistry.getRenderer(parsed, jsonMeta);
if (jsonRenderer) {
return (
<div className="p-4">{jsonRenderer.render(parsed, jsonMeta)}</div>
);
}
} catch {
// invalid JSON — fall through to plain text
}
}
// CSV: pass with explicit metadata so CSVRenderer matches
if (classification.type === "csv") {
const csvMeta = { mimeType: "text/csv", filename: artifact.title };
const csvRenderer = globalRegistry.getRenderer(content, csvMeta);
if (csvRenderer) {
return <div className="p-4">{csvRenderer.render(content, csvMeta)}</div>;
}
}
// Try the global renderer registry
const metadata = {
mimeType: artifact.mimeType ?? undefined,
filename: artifact.title,
};
const renderer = globalRegistry.getRenderer(content, metadata);
if (renderer) {
return <div className="p-4">{renderer.render(content, metadata)}</div>;
}
// Fallback: plain text
return (
<pre className="whitespace-pre-wrap break-words p-4 font-mono text-sm text-zinc-800">
{content}
</pre>
);
}
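// A note on the registry used above: globalRegistry ships elsewhere in this
// change, so this is an assumption, but getRenderer(value, meta) presumably
// walks the registered renderers in priority order and returns the first one
// whose matcher accepts the (value, meta) pair. That is why the code, JSON,
// and CSV branches pass explicit type metadata: it keeps the broader,
// higher-priority MarkdownRenderer from claiming the content first.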
export function ArtifactContent(props: Props) {
return (
<Suspense fallback={<ArtifactSkeleton />}>
<ArtifactContentLoader {...props} />
</Suspense>
);
}
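The `wrapWithHeadInjection` helper imported by the HTML branch isn't shown in this diff. A minimal sketch of the behavior its call site implies, as an assumption rather than the actual implementation:

```ts
// Sketch only: inject markup at the start of <head> when the document has
// one; otherwise wrap the bare fragment in a minimal HTML shell.
function wrapWithHeadInjection(html: string, injection: string): string {
  const head = html.match(/<head[^>]*>/i);
  if (head?.index !== undefined) {
    const at = head.index + head[0].length;
    return html.slice(0, at) + injection + html.slice(at);
  }
  return `<!DOCTYPE html><html><head>${injection}</head><body>${html}</body></html>`;
}
```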


@@ -0,0 +1,93 @@
"use client";
import { cn } from "@/lib/utils";
import { useEffect, useRef, useState } from "react";
import { DEFAULT_PANEL_WIDTH } from "../../../store";
interface Props {
onWidthChange: (width: number) => void;
minWidth?: number;
maxWidthPercent?: number;
}
export function ArtifactDragHandle({
onWidthChange,
minWidth = 320,
maxWidthPercent = 85,
}: Props) {
const [isDragging, setIsDragging] = useState(false);
const startXRef = useRef(0);
const startWidthRef = useRef(0);
// Use refs for the callback + bounds so the drag listeners can read the
// latest values without having to detach/reattach between re-renders.
const onWidthChangeRef = useRef(onWidthChange);
const minWidthRef = useRef(minWidth);
const maxWidthPercentRef = useRef(maxWidthPercent);
onWidthChangeRef.current = onWidthChange;
minWidthRef.current = minWidth;
maxWidthPercentRef.current = maxWidthPercent;
// Attach document listeners only while dragging, and always tear them down
// on unmount — otherwise closing the panel mid-drag leaves listeners bound
// to a handler that calls setState on the unmounted component.
useEffect(() => {
if (!isDragging) return;
function handlePointerMove(moveEvent: PointerEvent) {
const delta = startXRef.current - moveEvent.clientX;
const maxWidth = window.innerWidth * (maxWidthPercentRef.current / 100);
const newWidth = Math.min(
maxWidth,
Math.max(minWidthRef.current, startWidthRef.current + delta),
);
onWidthChangeRef.current(newWidth);
}
function handlePointerUp() {
setIsDragging(false);
}
document.addEventListener("pointermove", handlePointerMove);
document.addEventListener("pointerup", handlePointerUp);
document.addEventListener("pointercancel", handlePointerUp);
return () => {
document.removeEventListener("pointermove", handlePointerMove);
document.removeEventListener("pointerup", handlePointerUp);
document.removeEventListener("pointercancel", handlePointerUp);
};
}, [isDragging]);
function handlePointerDown(e: React.PointerEvent) {
e.preventDefault();
startXRef.current = e.clientX;
// Get the panel's current width from its parent
const panel = (e.target as HTMLElement).closest(
"[data-artifact-panel]",
) as HTMLElement | null;
startWidthRef.current = panel?.offsetWidth ?? DEFAULT_PANEL_WIDTH;
setIsDragging(true);
}
return (
    // 12px transparent hit target with the visible 1px line centered inside;
    // far easier to grab than the line itself, and in line with the ~8-12px
    // hit areas other resizable panels use.
<div
role="separator"
aria-orientation="vertical"
aria-label="Resize panel"
      className="group absolute -left-1.5 top-0 z-10 flex h-full w-3 cursor-col-resize items-stretch justify-center"
onPointerDown={handlePointerDown}
>
<div
className={cn(
"h-full w-px bg-transparent transition-colors group-hover:w-0.5 group-hover:bg-violet-400",
isDragging && "w-0.5 bg-violet-500",
)}
/>
</div>
);
}
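For context, the handle reads its panel's width through the `data-artifact-panel` attribute, so the parent must provide it. A minimal wiring sketch; the real panel component lives elsewhere in this change, and these names are hypothetical:

```tsx
import { useState, type ReactNode } from "react";
import { DEFAULT_PANEL_WIDTH } from "../../../store";
import { ArtifactDragHandle } from "./ArtifactDragHandle";

// Hypothetical host component: holds the width state the handle mutates.
export function ArtifactPanelFrame({ children }: { children: ReactNode }) {
  const [width, setWidth] = useState<number>(DEFAULT_PANEL_WIDTH);
  return (
    // data-artifact-panel is the marker handlePointerDown resolves via
    // closest() to read the panel's width at drag start.
    <aside data-artifact-panel className="relative h-full" style={{ width }}>
      <ArtifactDragHandle onWidthChange={setWidth} />
      {children}
    </aside>
  );
}
```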


@@ -0,0 +1,47 @@
"use client";
import { ArrowsOutSimple } from "@phosphor-icons/react";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
interface Props {
artifact: ArtifactRef;
classification: ArtifactClassification;
onExpand: () => void;
}
export function ArtifactMinimizedStrip({
artifact,
classification,
onExpand,
}: Props) {
const Icon = classification.icon;
return (
<div className="flex h-full w-10 flex-col items-center border-l border-zinc-200 bg-white pt-3">
<button
type="button"
onClick={onExpand}
className="rounded p-1.5 text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
title="Expand panel"
>
<ArrowsOutSimple size={16} />
</button>
<div className="mt-3 text-zinc-400">
<Icon size={16} />
</div>
<span
className="mt-2 text-xs text-zinc-400"
style={{
writingMode: "vertical-rl",
textOrientation: "mixed",
maxHeight: "120px",
overflow: "hidden",
textOverflow: "ellipsis",
}}
>
{artifact.title}
</span>
</div>
);
}


@@ -0,0 +1,138 @@
"use client";
import { cn } from "@/lib/utils";
import {
ArrowLeft,
ArrowsIn,
ArrowsOut,
Copy,
DownloadSimple,
Minus,
X,
} from "@phosphor-icons/react";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import { SourceToggle } from "./SourceToggle";
interface Props {
artifact: ArtifactRef;
classification: ArtifactClassification;
canGoBack: boolean;
isMaximized: boolean;
isSourceView: boolean;
hasSourceToggle: boolean;
mobile?: boolean;
canCopy?: boolean;
onBack: () => void;
onClose: () => void;
onMinimize: () => void;
onMaximize: () => void;
onRestore: () => void;
onCopy: () => void;
onDownload: () => void;
onSourceToggle: (isSource: boolean) => void;
}
function HeaderButton({
onClick,
title,
children,
}: {
onClick: () => void;
title: string;
children: React.ReactNode;
}) {
return (
<button
type="button"
onClick={onClick}
title={title}
aria-label={title}
className="rounded p-1.5 text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700"
>
{children}
</button>
);
}
export function ArtifactPanelHeader({
artifact,
classification,
canGoBack,
isMaximized,
isSourceView,
hasSourceToggle,
mobile,
canCopy = true,
onBack,
onClose,
onMinimize,
onMaximize,
onRestore,
onCopy,
onDownload,
onSourceToggle,
}: Props) {
const Icon = classification.icon;
return (
<div className="sticky top-0 z-10 flex items-center gap-2 border-b border-zinc-200 bg-white px-3 py-2">
{/* Left section */}
<div className="flex min-w-0 flex-1 items-center gap-2">
{canGoBack && (
<HeaderButton onClick={onBack} title="Back">
<ArrowLeft size={16} />
</HeaderButton>
)}
<Icon size={16} className="shrink-0 text-zinc-400" />
<span className="truncate text-sm font-medium text-zinc-900">
{artifact.title}
</span>
<span
className={cn(
"shrink-0 rounded-full px-2 py-0.5 text-xs font-medium",
artifact.origin === "user-upload"
? "bg-blue-50 text-blue-600"
: "bg-violet-50 text-violet-600",
)}
>
{classification.label}
</span>
</div>
{/* Right section */}
<div className="flex items-center gap-1">
{hasSourceToggle && (
<SourceToggle isSourceView={isSourceView} onToggle={onSourceToggle} />
)}
{canCopy && (
<HeaderButton onClick={onCopy} title="Copy">
<Copy size={16} />
</HeaderButton>
)}
<HeaderButton onClick={onDownload} title="Download">
<DownloadSimple size={16} />
</HeaderButton>
{!mobile && (
<>
<HeaderButton onClick={onMinimize} title="Minimize">
<Minus size={16} />
</HeaderButton>
{isMaximized ? (
<HeaderButton onClick={onRestore} title="Restore">
<ArrowsIn size={16} />
</HeaderButton>
) : (
<HeaderButton onClick={onMaximize} title="Maximize">
<ArrowsOut size={16} />
</HeaderButton>
)}
</>
)}
<HeaderButton onClick={onClose} title="Close">
<X size={16} />
</HeaderButton>
</div>
</div>
);
}

Some files were not shown because too many files have changed in this diff.