Compare commits


2 Commits

Author SHA1 Message Date
Zamil Majdy
f5e2eccda7 dx(orchestrate): fix stale-review gate and add pr-test evaluation rules to SKILL.md (#12701)
## Changes

### verify-complete.sh
- CHANGES_REQUESTED reviews are now compared against the latest commit
timestamp. If the review was submitted **before** the latest commit, it
is treated as stale and does not block verification.
- Added fail-closed guard: if the `gh pr view` fetch fails, the script
exits 1 (rather than treating missing data as "no blocking reviews")
- Fixed edge case: a `CHANGES_REQUESTED` review with a null
`submittedAt` is now counted as fresh/blocking (previously silently
skipped)
- Combined two separate `gh pr view` calls into one (`--json
commits,reviews`) to reduce API calls and ensure consistency
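The staleness comparison described above can be sketched with jq against a fixture standing in for the single `gh pr view --json commits,reviews` call (the `committedDate`/`submittedAt` fields are from the GitHub CLI's JSON output; the values here are illustrative):

```shell
# Fixture standing in for: gh pr view "$PR" --json commits,reviews
PR_JSON='{"commits":[{"committedDate":"2026-04-07T10:00:00Z"}],
  "reviews":[{"state":"CHANGES_REQUESTED","submittedAt":"2026-04-01T09:00:00Z"}]}'

# Latest commit timestamp (ISO 8601 strings compare lexicographically)
LATEST=$(jq -r '[.commits[].committedDate] | max' <<<"$PR_JSON")

# A review blocks only if it is CHANGES_REQUESTED and submitted after the
# latest commit; a null submittedAt counts as fresh (blocking), not skipped.
FRESH=$(jq --arg c "$LATEST" \
  '[.reviews[] | select(.state == "CHANGES_REQUESTED")
    | select(.submittedAt == null or .submittedAt > $c)] | length' <<<"$PR_JSON")

echo "fresh blocking reviews: $FRESH"   # review predates the commit → stale → 0
```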

### SKILL.md (orchestrate skill)
- Added `### /pr-test result evaluation` section with explicit
pass/partial/fail handling table
- **PARTIAL on any headline feature scenario = immediate blocker**:
re-brief the agent, fix, and re-run from scratch. Never approve or
output ORCHESTRATOR:DONE with a PARTIAL headline result.
- Concrete incident callout: PR #12699 S5 (Apply suggestions) was
PARTIAL — AI never output JSON action blocks — but was nearly approved.
This rule prevents recurrence.
- Updated `verify-complete.sh` description throughout to include "no
fresh CHANGES_REQUESTED"
- Added staleness rule documentation: a review only blocks if submitted
*after* the latest commit

## Why

Two separate incidents prompted these changes:

1. **verify-complete.sh false positive**: An automated bot
(autogpt-pr-reviewer) submitted a `CHANGES_REQUESTED` review in April.
An agent then pushed fixing commits. The old script still blocked on the
stale review, preventing the PR from being verified as done.

2. **Missed PARTIAL signal**: PR #12699 had a PARTIAL result on its
headline scenario (S5 Apply button) because the AI emitted direct
builder tool calls instead of JSON action blocks. The orchestrator
nearly approved it. The new SKILL.md rule makes PARTIAL = blocker
explicit.

## Checklist

- [x] I have read the contribution guide
- [x] My changes follow the code style of this project  
- [x] Changes are limited to the scope of this PR (< 20% unrelated
changes)
- [x] All new and existing tests pass
2026-04-08 08:58:42 +07:00
Zamil Majdy
58b230ff5a dx: add /orchestrate skill — Claude Code agent fleet supervisor with spare worktree lifecycle (#12691)
### Why

When running multiple Claude Code agents in parallel worktrees, they
frequently get stuck: an agent exits and sits at a shell prompt, freezes
mid-task, or waits on an approval prompt with no human watching. Fixing
this currently requires manually checking each tmux window.

### What

Adds a `/orchestrate` skill — a meta-agent supervisor that manages a
fleet of Claude Code agents across tmux windows and spare worktrees. It
auto-discovers available worktrees, spawns agents, monitors them, kicks
idle/stuck ones, auto-approves safe confirmations, and recycles
worktrees on completion.

### How to use

**Prerequisites:**
- One tmux session already running (the skill adds windows to it; it
does not create a new session)
- Spare worktrees on `spare/N` branches (e.g. `AutoGPT3` on `spare/3`,
`AutoGPT7` on `spare/7`)

**Basic workflow:**

```
/orchestrate capacity     → see how many spare worktrees are free
/orchestrate start        → enter task list, agents spawn automatically
/orchestrate status       → check what's running
/orchestrate add          → add one more task to the next free worktree
/orchestrate stop         → mark inactive (agents finish current work)
/orchestrate poll         → one manual poll cycle (debug / on-demand)
```

**Worktree lifecycle:**
```text
spare/N branch → /orchestrate add → new window + feat/branch + claude running
                                              ↓
                                     ORCHESTRATOR:DONE
                                              ↓
                              kill window + git checkout spare/N
                                              ↓
                                     spare/N (free again)
```

Windows are always capped by worktree count — no creep.

### Changes

- `.claude/skills/orchestrate/SKILL.md` — skill definition with 5
subcommands, state file schema, spawn/recycle helpers, approval policy
- `.claude/skills/orchestrate/scripts/classify-pane.sh` — pane state
classifier: `idle` (shell foreground), `running` (non-shell),
`waiting_approval` (pattern match), `complete` (ORCHESTRATOR:DONE)
- `.claude/skills/orchestrate/scripts/poll-cycle.sh` — poll loop:
reads/updates state file atomically, outputs JSON action list, stuck
detection via output-hash sampling
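The output-hash sampling can be sketched as below; the threshold and variable names are illustrative (the real script hashes `tmux capture-pane` output each poll):

```shell
STUCK_THRESHOLD=3            # consecutive identical samples before "stuck"
LAST_HASH=""; SAME_COUNT=0; STATE=running

check_sample() {             # called once per poll with the captured pane text
  local hash
  hash=$(printf '%s' "$1" | sha256sum | cut -d' ' -f1)
  if [ "$hash" = "$LAST_HASH" ]; then
    SAME_COUNT=$((SAME_COUNT + 1))
  else
    SAME_COUNT=0             # output changed: agent is making progress
  fi
  LAST_HASH=$hash
  if [ "$SAME_COUNT" -ge "$STUCK_THRESHOLD" ]; then
    STATE=stuck
  fi
}

check_sample "building..."; check_sample "building..."
check_sample "building..."; check_sample "building..."
echo "$STATE"                # pane unchanged across 3 consecutive polls
```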

**State detection:**

| State | Detection method |
|---|---|
| `idle` | `pane_current_command` is a shell (zsh/bash/fish) |
| `running` | `pane_current_command` is non-shell (claude/node) |
| `stuck` | pane hash unchanged for N consecutive polls |
| `waiting_approval` | pattern match on last 40 lines of pane output |
| `complete` | `ORCHESTRATOR:DONE` string present in pane output |
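A simplified version of that detection logic, as a pure function (the approval-prompt patterns are illustrative, and stuck detection is handled separately by hash sampling):

```shell
classify() {  # classify PANE_CMD PANE_TAIL → prints one state name
  local cmd="$1" tail="$2"
  if grep -q 'ORCHESTRATOR:DONE' <<<"$tail"; then
    echo complete
  elif grep -Eq 'Do you want to|Enter to confirm' <<<"$tail"; then
    echo waiting_approval            # pattern list is illustrative
  elif [[ "$cmd" =~ ^(zsh|bash|fish)$ ]]; then
    echo idle                        # shell in foreground: agent exited
  else
    echo running                     # claude/node still in foreground
  fi
}

classify zsh  ""                         # idle
classify node "thinking..."              # running
classify node "Do you want to proceed?"  # waiting_approval
classify zsh  "ORCHESTRATOR:DONE"        # complete
```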

**Safety policy for auto-approvals:** git ops, package installs, tests,
docker compose → approve. `rm -rf` outside worktree, force push, `sudo`,
secrets → escalate to user.
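One way to sketch that policy as a command classifier; the patterns are illustrative, and anything unrecognized fails toward escalation:

```shell
approve_or_escalate() {  # prints "approve" or "escalate" for a command string
  case "$1" in
    *'rm -rf'*|*'push --force'*|*sudo\ *|*secret*)
      echo escalate ;;                  # destructive/sensitive: check first!
    git\ *|npm\ install*|pip\ install*|pytest*|'docker compose'*)
      echo approve ;;                   # routine dev ops: auto-approve
    *)
      echo escalate ;;                  # unknown: fail closed
  esac
}

approve_or_escalate "git fetch origin"      # approve
approve_or_escalate "sudo rm -rf /"         # escalate
approve_or_escalate "docker compose up -d"  # approve
```

Note the escalation patterns come first, so e.g. `git push --force` is caught before the broad `git *` approval.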

State file lives at `~/.claude/orchestrator-state.json` (outside repo,
never committed).

### Checklist

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] `classify-pane.sh`: idle shell → `idle`, running process →
`running`, `ORCHESTRATOR:DONE` → `complete`, approval prompt →
`waiting_approval`, nonexistent window → `error`
- [x] `poll-cycle.sh`: inactive state → `[]`, empty agents array → `[]`,
spare worktree discovery, stuck detection (3-poll hash cycle)
- [x] Real agent spawn in `autogpt1` tmux session — agent ran, output
`ORCHESTRATOR:DONE`, recycle verified
  - [x] Upfront JSON validation before `set -e`-guarded jq reads
- [x] Idle timer reset only on `idle → running` transition (not stuck),
preventing false stuck-detections
- [x] Classify fallback only triggers when output is empty (no
double-JSON on classify exit 1)
2026-04-08 00:18:32 +07:00
73 changed files with 1683 additions and 6522 deletions

View File

@@ -0,0 +1,545 @@
---
name: orchestrate
description: "Meta-agent supervisor that manages a fleet of Claude Code agents running in tmux windows. Auto-discovers spare worktrees, spawns agents, monitors state, kicks idle agents, approves safe confirmations, and recycles worktrees when done. TRIGGER when user asks to supervise agents, run parallel tasks, manage worktrees, check agent status, or orchestrate parallel work."
user-invocable: true
argument-hint: "any free text — e.g. 'start 3 agents on X Y Z', 'show status', 'add task: implement feature A', 'stop', 'how many are free?'"
metadata:
  author: autogpt-team
  version: "6.0.0"
---
# Orchestrate — Agent Fleet Supervisor
One tmux session, N windows — each window is one agent working in its own worktree. Speak naturally; Claude maps your intent to the right scripts.
## Scripts
```bash
SKILLS_DIR=$(git rev-parse --show-toplevel)/.claude/skills/orchestrate/scripts
STATE_FILE=~/.claude/orchestrator-state.json
```
| Script | Purpose |
|---|---|
| `find-spare.sh [REPO_ROOT]` | List free worktrees — one `PATH BRANCH` per line |
| `spawn-agent.sh SESSION PATH SPARE NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]` | Create window + checkout branch + launch claude + send task. **Stdout: `SESSION:WIN` only** |
| `recycle-agent.sh WINDOW PATH SPARE_BRANCH` | Kill window + restore spare branch |
| `run-loop.sh` | **Mechanical babysitter** — idle restart + dialog approval + recycle on ORCHESTRATOR:DONE + supervisor health check + all-done notification |
| `verify-complete.sh WINDOW` | Verify PR is done: checkpoints ✓ + 0 unresolved threads + CI green + no fresh CHANGES_REQUESTED. Repo auto-derived from state file `.repo` or git remote. |
| `notify.sh MESSAGE` | Send notification via Discord webhook (env `DISCORD_WEBHOOK_URL` or state `.discord_webhook`), macOS notification center, and stdout |
| `capacity.sh [REPO_ROOT]` | Print available + in-use worktrees |
| `status.sh` | Print fleet status + live pane commands |
| `poll-cycle.sh` | One monitoring cycle — classifies panes, tracks checkpoints, returns JSON action array |
| `classify-pane.sh WINDOW` | Classify one pane state |
## Supervision model
```
Orchestrating Claude (this Claude session — IS the supervisor)
└── Reads pane output, checks CI, intervenes with targeted guidance
run-loop.sh (separate tmux window, every 30s)
└── Mechanical only: idle restart, dialog approval, recycle on ORCHESTRATOR:DONE
```
**You (the orchestrating Claude)** are the supervisor. After spawning agents, stay in this conversation and actively monitor: poll each agent's pane every 2-3 minutes, check CI, nudge stalled agents, and verify completions. Do not spawn a separate supervisor Claude window — it loses context, is hard to observe, and compounds context compression problems.
**run-loop.sh** is the mechanical layer — zero tokens, handles things that need no judgment: restart crashed agents, press Enter on dialogs, recycle completed worktrees (only after `verify-complete.sh` passes).
## Checkpoint protocol
Agents output checkpoints as they complete each required step:
```
CHECKPOINT:<step-name>
```
Required steps are passed as args to `spawn-agent.sh` (e.g. `pr-address pr-test`). `run-loop.sh` will not recycle a window until all required checkpoints are found in the pane output. If `verify-complete.sh` fails, the agent is re-briefed automatically.
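The checkpoint gate can be sketched as a scan of the captured pane output (a minimal sketch; the fixture and variable names are illustrative):

```shell
REQUIRED=(pr-address pr-test)     # steps as passed to spawn-agent.sh
PANE_OUT=$'CHECKPOINT:pr-address\nrunning tests\nCHECKPOINT:pr-test\nORCHESTRATOR:DONE'

ALL_FOUND=1
for step in "${REQUIRED[@]}"; do
  # -x: whole-line match, so "CHECKPOINT:pr-address" can't satisfy "pr-add"
  grep -qx "CHECKPOINT:$step" <<<"$PANE_OUT" || { echo "missing: $step"; ALL_FOUND=0; }
done
[ "$ALL_FOUND" -eq 1 ] && echo "all checkpoints found; eligible for recycle"
```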
## Worktree lifecycle
```text
spare/N branch → spawn-agent.sh (--session-id UUID) → window + feat/branch + claude running
                      ↓
         CHECKPOINT:<step> (as steps complete)
                      ↓
              ORCHESTRATOR:DONE
                      ↓
verify-complete.sh: checkpoints ✓ + 0 threads + CI green + no fresh CHANGES_REQUESTED
                      ↓
       state → "done", notify, window KEPT OPEN
                      ↓
      user/orchestrator explicitly requests recycle
                      ↓
         recycle-agent.sh → spare/N (free again)
```
**Windows are never auto-killed.** The worktree stays on its branch, the session stays alive. The agent is done working but the window, git state, and Claude session are all preserved until you choose to recycle.
**To resume a done or crashed session:**
```bash
# Resume by stored session ID (preferred — exact session, full context)
claude --resume SESSION_ID --permission-mode bypassPermissions
# Or resume most recent session in that worktree directory
cd /path/to/worktree && claude --continue --permission-mode bypassPermissions
```
**To manually recycle when ready:**
```bash
bash ~/.claude/orchestrator/scripts/recycle-agent.sh SESSION:WIN WORKTREE_PATH spare/N
# Then update state:
jq --arg w "SESSION:WIN" '.agents |= map(if .window == $w then .state = "recycled" else . end)' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
## State file (`~/.claude/orchestrator-state.json`)
Never committed to git. You maintain this file directly using `jq` + atomic writes (write to a `.tmp` file, then `mv` into place).
```json
{
  "active": true,
  "tmux_session": "autogpt1",
  "idle_threshold_seconds": 300,
  "loop_window": "autogpt1:5",
  "repo": "Significant-Gravitas/AutoGPT",
  "discord_webhook": "https://discord.com/api/webhooks/...",
  "last_poll_at": 0,
  "agents": [
    {
      "window": "autogpt1:3",
      "worktree": "AutoGPT6",
      "worktree_path": "/path/to/AutoGPT6",
      "spare_branch": "spare/6",
      "branch": "feat/my-feature",
      "objective": "Implement X and open a PR",
      "pr_number": "12345",
      "session_id": "550e8400-e29b-41d4-a716-446655440000",
      "steps": ["pr-address", "pr-test"],
      "checkpoints": ["pr-address"],
      "state": "running",
      "last_output_hash": "",
      "last_seen_at": 0,
      "spawned_at": 0,
      "idle_since": 0,
      "revision_count": 0,
      "last_rebriefed_at": 0
    }
  ]
}
```
Top-level optional fields:
- `repo` — GitHub `owner/repo` for CI/thread checks. Auto-derived from git remote if omitted.
- `discord_webhook` — Discord webhook URL for completion notifications. Also reads `DISCORD_WEBHOOK_URL` env var.
Per-agent fields:
- `session_id` — UUID passed to `claude --session-id` at spawn; use with `claude --resume UUID` to restore exact session context after a crash or window close.
- `last_rebriefed_at` — Unix timestamp of last re-brief; enforces 5-min cooldown to prevent spam.
Agent states: `running` | `idle` | `stuck` | `waiting_approval` | `complete` | `done` | `escalated`
`done` means verified complete — window is still open, session still alive, worktree still on task branch. Not recycled yet.
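The `last_rebriefed_at` cooldown can be checked like this before sending a re-brief (a sketch; the field name is from the schema above, the temp-file fixture stands in for the real state file):

```shell
STATE_FILE=$(mktemp)        # stand-in for ~/.claude/orchestrator-state.json
NOW=$(date +%s)
jq -n --arg w "autogpt1:3" --argjson t "$((NOW - 120))" \
  '{agents:[{window:$w, last_rebriefed_at:$t}]}' > "$STATE_FILE"

LAST=$(jq -r --arg w "autogpt1:3" \
  '.agents[] | select(.window == $w) | .last_rebriefed_at // 0' "$STATE_FILE")

if [ $((NOW - LAST)) -lt 300 ]; then
  echo "cooldown active; skip re-brief"   # last re-brief was 2 min ago
else
  echo "ok to re-brief"
fi
rm -f "$STATE_FILE"
```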
## Serial /pr-test rule
`/pr-test` and `/pr-test --fix` run local Docker + integration tests that use shared ports, a shared database, and shared build caches. **Running two `/pr-test` jobs simultaneously will cause port conflicts and database corruption.**
**Rule: only one `/pr-test` runs at a time. The orchestrator serializes them.**
You (the orchestrating Claude) own the test queue:
1. Agents do `pr-review` and `pr-address` in parallel — that's safe (they only push code and reply to GitHub).
2. When a PR needs local testing, add it to your mental queue — don't give agents a `pr-test` step.
3. Run `/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER --fix` yourself, sequentially.
4. Feed results back to the relevant agent via `tmux send-keys`:
```bash
tmux send-keys -t SESSION:WIN "Local tests for PR #N: <paste failure output or 'all passed'>. Fix any failures and push, then output ORCHESTRATOR:DONE."
sleep 0.3
tmux send-keys -t SESSION:WIN Enter
```
5. Wait for CI to confirm green before marking the agent done.
If multiple PRs need testing at the same time, pick the one furthest along (fewest pending CI checks) and test it first. Only start the next test after the previous one completes.
## Session restore (tested and confirmed)
Agent sessions are saved to disk. To restore a closed or crashed session:
```bash
# If session_id is in state (preferred):
NEW_WIN=$(tmux new-window -t SESSION -n WORKTREE_NAME -P -F '#{window_index}')
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --resume SESSION_ID --permission-mode bypassPermissions" Enter
# If no session_id (use --continue for most recent session in that directory):
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --continue --permission-mode bypassPermissions" Enter
```
`--continue` restores the full conversation history including all tool calls, file edits, and context. The agent resumes exactly where it left off. After restoring, update the window address in the state file:
```bash
jq --arg old "SESSION:OLD_WIN" --arg new "SESSION:NEW_WIN" \
'(.agents[] | select(.window == $old)).window = $new' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
## Intent → action mapping
Match the user's message to one of these intents:
| The user says something like… | What to do |
|---|---|
| "status", "what's running", "show agents" | Run `status.sh` + `capacity.sh`, show output |
| "how many free", "capacity", "available worktrees" | Run `capacity.sh`, show output |
| "start N agents on X, Y, Z" or "run these tasks: …" | See **Spawning agents** below |
| "add task: …", "add one more agent for …" | See **Adding an agent** below |
| "stop", "shut down", "pause the fleet" | See **Stopping** below |
| "poll", "check now", "run a cycle" | Run `poll-cycle.sh`, process actions |
| "recycle window X", "free up autogpt3" | Run `recycle-agent.sh` directly |
When the intent is ambiguous, show capacity first and ask what tasks to run.
## Spawning agents
### 1. Resolve tmux session
```bash
tmux list-sessions -F "#{session_name}: #{session_windows} windows" 2>/dev/null
```
Use an existing session. **Never create a tmux session from within Claude** — it becomes a child of Claude's process and dies when the session ends. If no session exists, tell the user to run `tmux new-session -d -s autogpt1` in their terminal first, then re-invoke `/orchestrate`.
### 2. Show available capacity
```bash
bash $SKILLS_DIR/capacity.sh $(git rev-parse --show-toplevel)
```
### 3. Collect tasks from the user
For each task, gather:
- **objective** — what to do (e.g. "implement feature X and open a PR")
- **branch name** — e.g. `feat/my-feature` (derive from objective if not given)
- **pr_number** — GitHub PR number if working on an existing PR (for verification)
- **steps** — required checkpoint names in order (e.g. `pr-address pr-test`) — derive from objective
Ask for `idle_threshold_seconds` only if the user mentions it (default: 300).
Never ask the user to specify a worktree — auto-assign from `find-spare.sh`.
### 4. Spawn one agent per task
```bash
# Get ordered list of spare worktrees
SPARE_LIST=$(bash $SKILLS_DIR/find-spare.sh $(git rev-parse --show-toplevel))
# For each task, take the next spare line:
WORKTREE_PATH=$(echo "$SPARE_LINE" | awk '{print $1}')
SPARE_BRANCH=$(echo "$SPARE_LINE" | awk '{print $2}')
# With PR number and required steps:
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE" "$PR_NUMBER" "pr-address" "pr-test")
# Without PR (new work):
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE")
```
Build an agent record and append it to the state file. If the state file doesn't exist yet, initialize it:
```bash
# Derive repo from git remote (used by verify-complete.sh + supervisor)
REPO=$(git remote get-url origin 2>/dev/null | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
jq -n \
--arg session "$SESSION" \
--arg repo "$REPO" \
--argjson threshold 300 \
'{active:true, tmux_session:$session, idle_threshold_seconds:$threshold,
repo:$repo, loop_window:null, supervisor_window:null, last_poll_at:0, agents:[]}' \
> ~/.claude/orchestrator-state.json
```
Optionally add a Discord webhook for completion notifications:
```bash
jq --arg hook "$DISCORD_WEBHOOK_URL" '.discord_webhook = $hook' ~/.claude/orchestrator-state.json \
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
`spawn-agent.sh` writes the initial agent record (window, worktree_path, branch, objective, state, etc.) to the state file automatically — **do not append the record again after calling it.** The record already exists and `pr_number`/`steps` are patched in by the script itself.
### 5. Start the mechanical babysitter
```bash
LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter
jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
### 6. Begin supervising directly in this conversation
You are the supervisor. After spawning, immediately start your first poll loop (see **Supervisor duties** below) and continue every 2-3 minutes. Do NOT spawn a separate supervisor Claude window.
## Adding an agent
Find the next spare worktree, then spawn it — same as steps 2-4 above but for a single task (`spawn-agent.sh` records it in the state file). If no spare worktrees are available, tell the user.
## Supervisor duties (YOUR job, every 2-3 min in this conversation)
You are the supervisor. Run this poll loop directly in your Claude session — not in a separate window.
### Poll loop mechanism
You are reactive — you only act when a tool completes or the user sends a message. To create a self-sustaining poll loop without user involvement:
1. Start each poll with `run_in_background: true` + a sleep before the work:
```bash
sleep 120 && tmux capture-pane -t autogpt1:0 -p -S -200 | tail -40
# + similar for each active window
```
2. When the background job notifies you, read the pane output and take action.
3. Immediately schedule the next background poll — this keeps the loop alive.
4. Stop scheduling when all agents are done/escalated.
**Never tell the user "I'll poll every 2-3 minutes"** — that does nothing without a trigger. Start the background job instead.
### Each poll: what to check
```bash
# 1. Read state
cat ~/.claude/orchestrator-state.json | jq '.agents[] | {window, worktree, branch, state, pr_number, checkpoints}'
# 2. For each running/stuck/idle agent, capture pane
tmux capture-pane -t SESSION:WIN -p -S -200 | tail -60
```
For each agent, decide:
| What you see | Action |
|---|---|
| Spinner / tools running | Do nothing — agent is working |
| Idle shell prompt, no `ORCHESTRATOR:DONE` | Stalled — send specific nudge with objective from state |
| Stuck in error loop | Send targeted fix with exact error + solution |
| Waiting for input / question | Answer and unblock via `tmux send-keys` |
| CI red | `gh pr checks PR_NUMBER --repo REPO` → tell agent exactly what's failing |
| Context compacted / agent lost | Send recovery: `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="WIN")'` + `gh pr view PR_NUMBER --json title,body` |
| `ORCHESTRATOR:DONE` in output | Run `verify-complete.sh` — if it fails, re-brief with specific reason |
### Strict ORCHESTRATOR:DONE gate
`verify-complete.sh` handles the main checks automatically (checkpoints, threads, CI green, spawned_at, and CHANGES_REQUESTED).
**CHANGES_REQUESTED staleness rule**: a `CHANGES_REQUESTED` review only blocks if it was submitted *after* the latest commit. If the latest commit postdates the review, the review is considered stale (feedback already addressed) and does not block. This avoids false blocks when a bot reviewer hasn't re-reviewed after the agent's fixing commits. Run it:
```bash
SKILLS_DIR=~/.claude/orchestrator/scripts
bash $SKILLS_DIR/verify-complete.sh SESSION:WIN
```
If it passes → run-loop.sh will recycle the window automatically. No manual action needed.
If it fails → re-brief the agent with the failure reason. Never manually mark state `done` to bypass this.
### Re-brief a stalled agent
```bash
OBJ=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .objective' ~/.claude/orchestrator-state.json)
PR=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .pr_number' ~/.claude/orchestrator-state.json)
tmux send-keys -t SESSION:WIN "You appear stalled. Your objective: $OBJ. Check: gh pr view $PR --json title,body,headRefName to reorient."
sleep 0.3
tmux send-keys -t SESSION:WIN Enter
```
If `image_path` is set on the agent record, include: "Re-read context at IMAGE_PATH with the Read tool."
## Self-recovery protocol (agents)
spawn-agent.sh automatically includes this instruction in every objective:
> If your context compacts and you lose track of what to do, run:
> `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="SESSION:WIN")'`
> and `gh pr view PR_NUMBER --json title,body,headRefName` to reorient.
> Output each completed step as `CHECKPOINT:<step-name>` on its own line.
## Passing images and screenshots to agents
`tmux send-keys` is text-only — you cannot paste a raw image into a pane. To give an agent visual context (screenshots, diagrams, mockups):
1. **Save the image to a temp file** with a stable path:
```bash
# If the user drags in a screenshot or you receive a file path:
IMAGE_PATH="/tmp/orchestrator-context-$(date +%s).png"
cp "$USER_PROVIDED_PATH" "$IMAGE_PATH"
```
2. **Reference the path in the objective string**:
```bash
OBJECTIVE="Implement the layout shown in /tmp/orchestrator-context-1234567890.png. Read that image first with the Read tool to understand the design."
```
3. The agent uses its `Read` tool to view the image at startup — Claude Code agents are multimodal and can read image files directly.
**Rule**: always use `/tmp/orchestrator-context-<timestamp>.png` as the naming convention so the supervisor knows what to look for if it needs to re-brief an agent with the same image.
---
## Orchestrator final evaluation (YOU decide, not the script)
`verify-complete.sh` is a gate — it blocks premature marking. But it cannot tell you if the work is actually good. That is YOUR job.
When run-loop marks an agent `pending_evaluation` and you're notified, do all of these before marking done:
### 1. Run /pr-test (required, serialized, use TodoWrite to queue)
`/pr-test` is the only reliable confirmation that the objective is actually met. Run it yourself, not the agent.
**When multiple PRs reach `pending_evaluation` at the same time, use TodoWrite to queue them:**
```
- [ ] /pr-test PR #12636 — fix copilot retry logic
- [ ] /pr-test PR #12699 — builder chat panel
```
Run one at a time. Check off as you go.
```
/pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/PR_NUMBER
```
**/pr-test can be lazy** — if it gives vague output, re-run with full context:
```
/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER
Context: This PR implements <objective from state file>. Key files: <list>.
Please verify: <specific behaviors to check>.
```
Only one `/pr-test` at a time — they share ports and DB.
### /pr-test result evaluation
**PARTIAL on any headline feature scenario is an immediate blocker.** Do not approve, do not mark done, do not let the agent output `ORCHESTRATOR:DONE`.
| `/pr-test` result | Action |
|---|---|
| All headline scenarios **PASS** | Proceed to evaluation step 2 |
| Any headline scenario **PARTIAL** | Re-brief the agent immediately — see below |
| Any headline scenario **FAIL** | Re-brief the agent immediately |
**What PARTIAL means**: the feature is only partly working. Example: the Apply button never appeared, or the AI returned no action blocks. The agent addressed part of the objective but not all of it.
**When any headline scenario is PARTIAL or FAIL:**
1. Do NOT mark the agent done or accept `ORCHESTRATOR:DONE`
2. Re-brief the agent with the specific scenario that failed and what was missing:
```bash
tmux send-keys -t SESSION:WIN "PARTIAL result on /pr-test — S5 (Apply button) never appeared. The AI must output JSON action blocks for the Apply button to render. Fix this before re-running /pr-test."
sleep 0.3
tmux send-keys -t SESSION:WIN Enter
```
3. Set state back to `running`:
```bash
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "running"' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
4. Wait for new `ORCHESTRATOR:DONE`, then re-run `/pr-test` from scratch
**Rule: only ALL-PASS qualifies for approval.** A mix of PASS + PARTIAL is a failure.
> **Why this matters**: PR #12699 was nearly approved with S5 PARTIAL — the AI never output JSON action blocks, so the Apply button never appeared. The fix was within the agent's reach but almost slipped through because PARTIAL was not treated as blocking.
### 2. Do your own evaluation
1. **Read the PR diff and objective** — does the code actually implement what was asked? Is anything obviously missing or half-done?
2. **Read the resolved threads** — were comments addressed with real fixes, or just dismissed/resolved without changes?
3. **Check CI run names** — any suspicious retries that shouldn't have passed?
4. **Check the PR description** — title, summary, test plan complete?
### 3. Decide
- `/pr-test` all scenarios PASS + evaluation looks good → mark `done` in state, tell the user the PR is ready, ask if window should be closed
- `/pr-test` any scenario PARTIAL or FAIL → re-brief the agent with the specific failing scenario, set state back to `running` (see `/pr-test result evaluation` above)
- Evaluation finds gaps even with all PASS → re-brief the agent with specific gaps, set state back to `running`
**Never mark done based purely on script output.** You hold the full objective context; the script does not.
```bash
# Mark done after your positive evaluation:
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "done"' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
## When to stop the fleet
Stop the fleet (`active = false`) when **all** of the following are true:
| Check | How to verify |
|---|---|
| All agents are `done` or `escalated` | `jq '[.agents[] | select(.state | test("running\|stuck\|idle\|waiting_approval"))] | length' ~/.claude/orchestrator-state.json` == 0 |
| All PRs have 0 unresolved review threads | GraphQL `isResolved` check per PR |
| All PRs have green CI **on a run triggered after the agent's last push** | `gh run list --branch BRANCH --limit 1` timestamp > `spawned_at` in state |
| No fresh CHANGES_REQUESTED (after latest commit) | `verify-complete.sh` checks this — stale pre-commit reviews are ignored |
| No agents are `escalated` without human review | If any are escalated, surface to user first |
**Do NOT stop just because agents output `ORCHESTRATOR:DONE`.** That is a signal to verify, not a signal to stop.
**Do stop** if the user explicitly says "stop", "shut down", or "kill everything", even with agents still running.
```bash
# Graceful stop
jq '.active = false' ~/.claude/orchestrator-state.json > /tmp/orch.tmp \
&& mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
LOOP_WINDOW=$(jq -r '.loop_window // ""' ~/.claude/orchestrator-state.json)
[ -n "$LOOP_WINDOW" ] && tmux kill-window -t "$LOOP_WINDOW" 2>/dev/null || true
```
Stopping does **not** recycle running worktrees — agents may still be mid-task. Run `capacity.sh` to see what's still in progress.
## tmux send-keys pattern
**Always split long messages into text + Enter as two separate calls with a sleep between them.** If sent as one call (`"text" Enter`), Enter can fire before the full string is buffered into Claude's input — leaving the message stuck as `[Pasted text +N lines]` unsent.
```bash
# CORRECT — text then Enter separately
tmux send-keys -t "$WINDOW" "your long message here"
sleep 0.3
tmux send-keys -t "$WINDOW" Enter
# WRONG — Enter may fire before text is buffered
tmux send-keys -t "$WINDOW" "your long message here" Enter
```
Short single-character sends (`y`, `Down`, empty Enter for dialog approval) are safe to combine since they have no buffering lag.
---
## Protected worktrees
Some worktrees must **never** be used as spare worktrees for agent tasks because they host files critical to the orchestrator itself:
| Worktree | Protected branch | Why |
|---|---|---|
| `AutoGPT1` | `dx/orchestrate-skill` | Hosts the orchestrate skill scripts. `recycle-agent.sh` would check out `spare/1`, wiping `.claude/skills/` and breaking all subsequent `spawn-agent.sh` calls. |
**Rule**: when selecting spare worktrees via `find-spare.sh`, skip any worktree whose CURRENT branch matches a protected branch. If you accidentally spawn an agent in a protected worktree, do not let `recycle-agent.sh` run on it — manually restore the branch after the agent finishes.
When `dx/orchestrate-skill` is merged into `dev`, `AutoGPT1` becomes a normal spare again.
---
## Key rules
1. **Scripts do all the heavy lifting** — don't reimplement their logic inline in this file
2. **Never ask the user to pick a worktree** — auto-assign from `find-spare.sh` output
3. **Never restart a running agent** — only restart on `idle` kicks (foreground is a shell)
4. **Auto-dismiss settings dialogs** — if "Enter to confirm" appears, send Down+Enter
5. **Always `--permission-mode bypassPermissions`** on every spawn
6. **Escalate after 3 kicks** — mark `escalated`, surface to user
7. **Atomic state writes** — always write to `.tmp` then `mv`
8. **Never approve destructive commands** outside the worktree scope — when in doubt, escalate
9. **Never recycle without verification** — `verify-complete.sh` must pass before recycling
10. **No TASK.md files** — commit risk; use state file + `gh pr view` for agent context persistence
11. **Re-brief stalled agents** — read objective from state file + `gh pr view`, send via tmux
12. **ORCHESTRATOR:DONE is a signal to verify, not to accept** — always run `verify-complete.sh` and check CI run timestamp before recycling
13. **Protected worktrees** — never use the worktree hosting the skill scripts as a spare
14. **Images via file path** — save screenshots to `/tmp/orchestrator-context-<ts>.png`, pass path in objective; agents read with the `Read` tool
15. **Split send-keys** — always separate text and Enter with `sleep 0.3` between calls for long strings
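Rule 7's write pattern, as the scripts below use it, reduces to writing the full new content to a sibling `.tmp` file and then `mv`-ing it into place. A minimal sketch; `write_state` is an illustrative helper (the real scripts pipe `jq` output into the `.tmp` file):

```bash
# Atomic state write: a rename within the same filesystem is atomic,
# so readers never observe a partially written file.
write_state() {
  local file="$1" content="$2"
  printf '%s\n' "$content" > "${file}.tmp" && mv "${file}.tmp" "$file"
}
```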

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# capacity.sh — show fleet capacity: available spare worktrees + in-use agents
#
# Usage: capacity.sh [REPO_ROOT]
# REPO_ROOT defaults to the root worktree of the current git repo.
#
# Reads: ~/.claude/orchestrator-state.json (skipped if missing or corrupt)
set -euo pipefail
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
echo "=== Available (spare) worktrees ==="
if [ -n "$REPO_ROOT" ]; then
SPARE=$("$SCRIPTS_DIR/find-spare.sh" "$REPO_ROOT" 2>/dev/null || echo "")
else
SPARE=$("$SCRIPTS_DIR/find-spare.sh" 2>/dev/null || echo "")
fi
if [ -z "$SPARE" ]; then
echo " (none)"
else
while IFS= read -r line; do
[ -z "$line" ] && continue
echo "$line"
done <<< "$SPARE"
fi
echo ""
echo "=== In-use worktrees ==="
if [ -f "$STATE_FILE" ] && jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
IN_USE=$(jq -r '.agents[] | select(.state != "done") | " [\(.state)] \(.worktree_path) → \(.branch)"' \
"$STATE_FILE" 2>/dev/null || echo "")
if [ -n "$IN_USE" ]; then
echo "$IN_USE"
else
echo " (none)"
fi
else
echo " (no active state file)"
fi

View File

@@ -0,0 +1,85 @@
#!/usr/bin/env bash
# classify-pane.sh — Classify the current state of a tmux pane
#
# Usage: classify-pane.sh <tmux-target>
# tmux-target: e.g. "work:0", "work:1.0"
#
# Output (stdout): JSON object:
# { "state": "running|idle|waiting_approval|complete", "reason": "...", "pane_cmd": "..." }
#
# Exit codes: 0=ok, 1=error (invalid target or tmux window not found)
set -euo pipefail
TARGET="${1:-}"
if [ -z "$TARGET" ]; then
echo '{"state":"error","reason":"no target provided","pane_cmd":""}'
exit 1
fi
# Validate tmux target format: session:window or session:window.pane
if ! [[ "$TARGET" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
echo '{"state":"error","reason":"invalid tmux target format","pane_cmd":""}'
exit 1
fi
# Check session exists (use %%:* to extract session name from session:window)
if ! tmux list-windows -t "${TARGET%%:*}" >/dev/null 2>&1; then
echo '{"state":"error","reason":"tmux target not found","pane_cmd":""}'
exit 1
fi
# Get the current foreground command in the pane
PANE_CMD=$(tmux display-message -t "$TARGET" -p '#{pane_current_command}' 2>/dev/null || echo "unknown")
# Capture and strip ANSI codes (use perl for cross-platform compatibility — BSD sed lacks \x1b support)
RAW=$(tmux capture-pane -t "$TARGET" -p -S -50 2>/dev/null || echo "")
CLEAN=$(echo "$RAW" | perl -pe 's/\x1b\[[0-9;]*[a-zA-Z]//g; s/\x1b\(B//g; s/\x1b\[\?[0-9]*[hl]//g; s/\r//g' \
| grep -v '^[[:space:]]*$' || true)
# --- Check: explicit completion marker ---
# Must be on its own line (not buried in the objective text sent at spawn time).
if echo "$CLEAN" | grep -qE "^[[:space:]]*ORCHESTRATOR:DONE[[:space:]]*$"; then
jq -n --arg cmd "$PANE_CMD" '{"state":"complete","reason":"ORCHESTRATOR:DONE marker found","pane_cmd":$cmd}'
exit 0
fi
# --- Check: Claude Code approval prompt patterns ---
LAST_40=$(echo "$CLEAN" | tail -40)
APPROVAL_PATTERNS=(
"Do you want to proceed"
"Do you want to make this"
"\\[y/n\\]"
"\\[Y/n\\]"
"\\[n/Y\\]"
"Proceed\\?"
"Allow this command"
"Run bash command"
"Allow bash"
"Would you like"
"Press enter to continue"
"Esc to cancel"
)
for pattern in "${APPROVAL_PATTERNS[@]}"; do
if echo "$LAST_40" | grep -qiE "$pattern"; then
jq -n --arg pattern "$pattern" --arg cmd "$PANE_CMD" \
'{"state":"waiting_approval","reason":"approval pattern: \($pattern)","pane_cmd":$cmd}'
exit 0
fi
done
# --- Check: shell prompt (claude has exited) ---
# If the foreground process is a shell (not claude/node), the agent has exited
case "$PANE_CMD" in
zsh|bash|fish|sh|dash|tcsh|ksh)
jq -n --arg cmd "$PANE_CMD" \
'{"state":"idle","reason":"agent exited — shell prompt active","pane_cmd":$cmd}'
exit 0
;;
esac
# Agent is still running (claude/node/python is the foreground process)
jq -n --arg cmd "$PANE_CMD" \
'{"state":"running","reason":"foreground process: \($cmd)","pane_cmd":$cmd}'
exit 0

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env bash
# find-spare.sh — list worktrees on spare/N branches (free to use)
#
# Usage: find-spare.sh [REPO_ROOT]
# REPO_ROOT defaults to the root worktree containing the current git repo.
#
# Output (stdout): one line per available worktree: "PATH BRANCH"
# e.g.: /Users/me/Code/AutoGPT3 spare/3
set -euo pipefail
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
if [ -z "$REPO_ROOT" ]; then
echo "Error: not inside a git repo and no REPO_ROOT provided" >&2
exit 1
fi
git -C "$REPO_ROOT" worktree list --porcelain \
| awk '
/^worktree / { path = substr($0, 10) }
/^branch / { branch = substr($0, 8); print path " " branch }
' \
| { grep -E " refs/heads/spare/[0-9]+$" || true; } \
| sed 's|refs/heads/||'

View File

@@ -0,0 +1,40 @@
#!/usr/bin/env bash
# notify.sh — send a fleet notification message
#
# Delivery order (first available wins):
# 1. Discord webhook — DISCORD_WEBHOOK_URL env var OR state file .discord_webhook
# 2. macOS notification center — osascript (silent fail if unavailable)
# 3. Stdout only
#
# Usage: notify.sh MESSAGE
# Exit: always 0 (notification failure must not abort the caller)
MESSAGE="${1:-}"
[ -z "$MESSAGE" ] && exit 0
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
# --- Resolve Discord webhook ---
WEBHOOK="${DISCORD_WEBHOOK_URL:-}"
if [ -z "$WEBHOOK" ] && [ -f "$STATE_FILE" ]; then
WEBHOOK=$(jq -r '.discord_webhook // ""' "$STATE_FILE" 2>/dev/null || echo "")
fi
# --- Discord delivery ---
if [ -n "$WEBHOOK" ]; then
PAYLOAD=$(jq -n --arg msg "$MESSAGE" '{"content": $msg}')
curl -s -X POST "$WEBHOOK" \
-H "Content-Type: application/json" \
-d "$PAYLOAD" > /dev/null 2>&1 || true
fi
# --- macOS notification center (silent if not macOS or osascript missing) ---
if command -v osascript >/dev/null 2>&1; then
# Escape backslashes and double quotes for the AppleScript string literal
SAFE_MSG=$(printf '%s' "$MESSAGE" | sed 's/\\/\\\\/g; s/"/\\"/g')
osascript -e "display notification \"${SAFE_MSG}\" with title \"Orchestrator\"" 2>/dev/null || true
fi
# Always print to stdout so run-loop.sh logs it
echo "$MESSAGE"
exit 0

View File

@@ -0,0 +1,257 @@
#!/usr/bin/env bash
# poll-cycle.sh — Single orchestrator poll cycle
#
# Reads ~/.claude/orchestrator-state.json, classifies each agent, updates state,
# and outputs a JSON array of actions for Claude to take.
#
# Usage: poll-cycle.sh
# Output (stdout): JSON array of action objects
# [{ "window": "work:0", "action": "kick|approve|none", "state": "...",
# "worktree": "...", "objective": "...", "reason": "..." }]
#
# The state file is updated in-place (atomic write via .tmp).
set -euo pipefail
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
CLASSIFY="$SCRIPTS_DIR/classify-pane.sh"
# Cross-platform md5: always outputs just the hex digest
md5_hash() {
if command -v md5sum &>/dev/null; then
md5sum | awk '{print $1}'
else
md5 | awk '{print $NF}'
fi
}
# Clean up temp file on any exit (avoids stale .tmp if jq write fails)
trap 'rm -f "${STATE_FILE}.tmp"' EXIT
# Ensure state file exists
if [ ! -f "$STATE_FILE" ]; then
echo '{"active":false,"agents":[]}' > "$STATE_FILE"
fi
# Validate JSON upfront before any jq reads that run under set -e.
# A truncated/corrupt file (e.g. from a SIGKILL mid-write) would otherwise
# abort the script at the ACTIVE read below without emitting any JSON output.
if ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
echo "State file parse error — check $STATE_FILE" >&2
echo "[]"
exit 0
fi
ACTIVE=$(jq -r '.active // false' "$STATE_FILE")
if [ "$ACTIVE" != "true" ]; then
echo "[]"
exit 0
fi
NOW=$(date +%s)
IDLE_THRESHOLD=$(jq -r '.idle_threshold_seconds // 300' "$STATE_FILE")
ACTIONS="[]"
UPDATED_AGENTS="[]"
# Read agents as newline-delimited JSON objects.
# jq exits non-zero when .agents[] has no matches on an empty array, which is valid —
# so we suppress that exit code and separately validate the file is well-formed JSON.
if ! AGENTS_JSON=$(jq -e -c '.agents // empty | .[]' "$STATE_FILE" 2>/dev/null); then
if ! jq -e '.' "$STATE_FILE" > /dev/null 2>&1; then
echo "State file parse error — check $STATE_FILE" >&2
fi
echo "[]"
exit 0
fi
if [ -z "$AGENTS_JSON" ]; then
echo "[]"
exit 0
fi
while IFS= read -r agent; do
[ -z "$agent" ] && continue
# Use // "" defaults so a single malformed field doesn't abort the whole cycle
WINDOW=$(echo "$agent" | jq -r '.window // ""')
WORKTREE=$(echo "$agent" | jq -r '.worktree // ""')
OBJECTIVE=$(echo "$agent" | jq -r '.objective // ""')
STATE=$(echo "$agent" | jq -r '.state // "running"')
LAST_HASH=$(echo "$agent" | jq -r '.last_output_hash // ""')
IDLE_SINCE=$(echo "$agent" | jq -r '.idle_since // 0')
REVISION_COUNT=$(echo "$agent" | jq -r '.revision_count // 0')
# Validate window format to prevent tmux target injection.
# Allow session:window (numeric or named) and session:window.pane
if ! [[ "$WINDOW" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
echo "Skipping agent with invalid window value: $WINDOW" >&2
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
continue
fi
# Pass-through terminal-state agents
if [[ "$STATE" == "done" || "$STATE" == "escalated" || "$STATE" == "complete" || "$STATE" == "pending_evaluation" ]]; then
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
continue
fi
# Classify pane.
# classify-pane.sh always emits JSON before exit (even on error), so using
# "|| echo '...'" would concatenate two JSON objects when it exits non-zero.
# Use "|| true" inside the substitution so set -euo pipefail does not abort
# the poll cycle when classify exits with a non-zero status code.
CLASSIFICATION=$("$CLASSIFY" "$WINDOW" 2>/dev/null || true)
[ -z "$CLASSIFICATION" ] && CLASSIFICATION='{"state":"error","reason":"classify failed","pane_cmd":"unknown"}'
PANE_STATE=$(echo "$CLASSIFICATION" | jq -r '.state')
PANE_REASON=$(echo "$CLASSIFICATION" | jq -r '.reason')
# Capture full pane output once — used for hash (stuck detection) and checkpoint parsing.
# Use -S -500 to get the last ~500 lines of scrollback so checkpoints aren't missed.
RAW=$(tmux capture-pane -t "$WINDOW" -p -S -500 2>/dev/null || echo "")
# --- Checkpoint tracking ---
# Parse any "CHECKPOINT:<step>" lines the agent has output and merge into state file.
# The agent writes these as it completes each required step so verify-complete.sh can gate recycling.
EXISTING_CPS=$(echo "$agent" | jq -c '.checkpoints // []')
NEW_CHECKPOINTS_JSON="$EXISTING_CPS"
if [ -n "$RAW" ]; then
FOUND_CPS=$(echo "$RAW" \
| grep -oE "CHECKPOINT:[a-zA-Z0-9_-]+" \
| sed 's/CHECKPOINT://' \
| sort -u \
| jq -R . | jq -s . 2>/dev/null || echo "[]")
NEW_CHECKPOINTS_JSON=$(jq -n \
--argjson existing "$EXISTING_CPS" \
--argjson found "$FOUND_CPS" \
'($existing + $found) | unique' 2>/dev/null || echo "$EXISTING_CPS")
fi
# Compute content hash for stuck-detection (only for running agents)
CURRENT_HASH=""
if [[ "$PANE_STATE" == "running" ]] && [ -n "$RAW" ]; then
CURRENT_HASH=$(echo "$RAW" | tail -20 | md5_hash)
fi
NEW_STATE="$STATE"
NEW_IDLE_SINCE="$IDLE_SINCE"
NEW_REVISION_COUNT="$REVISION_COUNT"
ACTION="none"
REASON="$PANE_REASON"
case "$PANE_STATE" in
complete)
# Agent output ORCHESTRATOR:DONE — mark pending_evaluation so orchestrator handles it.
# run-loop does NOT verify or notify; orchestrator's background poll picks this up.
NEW_STATE="pending_evaluation"
ACTION="complete" # run-loop logs it but takes no action
;;
waiting_approval)
NEW_STATE="waiting_approval"
ACTION="approve"
;;
idle)
# Agent process has exited — needs restart
NEW_STATE="idle"
ACTION="kick"
REASON="agent exited (shell is foreground)"
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
NEW_IDLE_SINCE=$NOW
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
NEW_STATE="escalated"
ACTION="none"
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
fi
;;
running)
# Clear idle_since only when transitioning from idle (agent was kicked and
# restarted). Do NOT reset for stuck — idle_since must persist across polls
# so STUCK_DURATION can accumulate and trigger escalation.
# Also update the local IDLE_SINCE so the hash-stability check below uses
# the reset value on this same poll, not the stale kick timestamp.
if [[ "$STATE" == "idle" ]]; then
NEW_IDLE_SINCE=0
IDLE_SINCE=0
fi
# Check if hash has been stable (agent may be stuck mid-task)
if [ -n "$CURRENT_HASH" ] && [ "$CURRENT_HASH" = "$LAST_HASH" ] && [ "$LAST_HASH" != "" ]; then
if [ "$IDLE_SINCE" = "0" ] || [ "$IDLE_SINCE" = "null" ]; then
NEW_IDLE_SINCE=$NOW
else
STUCK_DURATION=$(( NOW - IDLE_SINCE ))
if [ "$STUCK_DURATION" -gt "$IDLE_THRESHOLD" ]; then
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
NEW_IDLE_SINCE=$NOW
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
NEW_STATE="escalated"
ACTION="none"
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
else
NEW_STATE="stuck"
ACTION="kick"
REASON="output unchanged for ${STUCK_DURATION}s (threshold: ${IDLE_THRESHOLD}s)"
fi
fi
fi
else
# Only reset the idle timer when we have a valid hash comparison (pane
# capture succeeded). If CURRENT_HASH is empty (tmux capture-pane failed),
# preserve existing timers so stuck detection is not inadvertently reset.
if [ -n "$CURRENT_HASH" ]; then
NEW_STATE="running"
NEW_IDLE_SINCE=0
fi
fi
;;
error)
REASON="classify error: $PANE_REASON"
;;
esac
# Build updated agent record (ensure idle_since and revision_count are numeric).
# On jq failure, keep the original record and continue so a malformed field
# doesn't abort the entire poll cycle under set -e.
UPDATED_AGENT=$(echo "$agent" | jq \
--arg state "$NEW_STATE" \
--arg hash "$CURRENT_HASH" \
--argjson now "$NOW" \
--arg idle_since "$NEW_IDLE_SINCE" \
--arg revision_count "$NEW_REVISION_COUNT" \
--argjson checkpoints "$NEW_CHECKPOINTS_JSON" \
'.state = $state
| .last_output_hash = (if $hash == "" then .last_output_hash else $hash end)
| .last_seen_at = $now
| .idle_since = ($idle_since | tonumber)
| .revision_count = ($revision_count | tonumber)
| .checkpoints = $checkpoints' 2>/dev/null) || {
echo "Warning: failed to build updated agent for window $WINDOW — keeping original" >&2
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
continue
}
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$UPDATED_AGENT" '. + [$a]')
# Add action if needed
if [ "$ACTION" != "none" ]; then
ACTION_OBJ=$(jq -n \
--arg window "$WINDOW" \
--arg action "$ACTION" \
--arg state "$NEW_STATE" \
--arg worktree "$WORKTREE" \
--arg objective "$OBJECTIVE" \
--arg reason "$REASON" \
'{window:$window, action:$action, state:$state, worktree:$worktree, objective:$objective, reason:$reason}')
ACTIONS=$(echo "$ACTIONS" | jq --argjson a "$ACTION_OBJ" '. + [$a]')
fi
done <<< "$AGENTS_JSON"
# Atomic state file update
jq --argjson agents "$UPDATED_AGENTS" \
--argjson now "$NOW" \
'.agents = $agents | .last_poll_at = $now' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
echo "$ACTIONS"

View File

@@ -0,0 +1,32 @@
#!/usr/bin/env bash
# recycle-agent.sh — kill a tmux window and restore the worktree to its spare branch
#
# Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH
# WINDOW — tmux target, e.g. autogpt1:3
# WORKTREE_PATH — absolute path to the git worktree
# SPARE_BRANCH — branch to restore, e.g. spare/6
#
# Stdout: one status line
set -euo pipefail
if [ $# -lt 3 ]; then
echo "Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH" >&2
exit 1
fi
WINDOW="$1"
WORKTREE_PATH="$2"
SPARE_BRANCH="$3"
# Kill the tmux window (ignore error — may already be gone)
tmux kill-window -t "$WINDOW" 2>/dev/null || true
# Restore to spare branch: abort any in-progress operation, then clean
git -C "$WORKTREE_PATH" rebase --abort 2>/dev/null || true
git -C "$WORKTREE_PATH" merge --abort 2>/dev/null || true
git -C "$WORKTREE_PATH" reset --hard HEAD 2>/dev/null
git -C "$WORKTREE_PATH" clean -fd 2>/dev/null
git -C "$WORKTREE_PATH" checkout "$SPARE_BRANCH"
echo "Recycled: $(basename "$WORKTREE_PATH")$SPARE_BRANCH (window $WINDOW closed)"

View File

@@ -0,0 +1,164 @@
#!/usr/bin/env bash
# run-loop.sh — Mechanical babysitter for the agent fleet (runs in its own tmux window)
#
# Handles ONLY two things that need no intelligence:
# idle → restart claude using --resume SESSION_ID (or --continue) to restore context
# approve → auto-approve safe dialogs, press Enter on numbered-option dialogs
#
# Everything else — ORCHESTRATOR:DONE, verification, /pr-test, final evaluation,
# marking done, deciding to close windows — is the orchestrating Claude's job.
# poll-cycle.sh sets state to pending_evaluation when ORCHESTRATOR:DONE is detected;
# the orchestrator's background poll loop handles it from there.
#
# Usage: run-loop.sh
# Env: POLL_INTERVAL (default: 30), ORCHESTRATOR_STATE_FILE
set -euo pipefail
# Copy scripts to a stable location outside the repo so they survive branch
# checkouts (e.g. recycle-agent.sh switching spare/N back into this worktree
# would wipe .claude/skills/orchestrate/scripts if the skill only exists on the
# current branch).
_ORIGIN_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
STABLE_SCRIPTS_DIR="$HOME/.claude/orchestrator/scripts"
mkdir -p "$STABLE_SCRIPTS_DIR"
cp "$_ORIGIN_DIR"/*.sh "$STABLE_SCRIPTS_DIR/"
chmod +x "$STABLE_SCRIPTS_DIR"/*.sh
SCRIPTS_DIR="$STABLE_SCRIPTS_DIR"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
POLL_INTERVAL="${POLL_INTERVAL:-30}"
# ---------------------------------------------------------------------------
# update_state WINDOW FIELD VALUE
# ---------------------------------------------------------------------------
update_state() {
local window="$1" field="$2" value="$3"
jq --arg w "$window" --arg f "$field" --arg v "$value" \
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
}
update_state_int() {
local window="$1" field="$2" value="$3"
jq --arg w "$window" --arg f "$field" --argjson v "$value" \
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
}
agent_field() {
jq -r --arg w "$1" --arg f "$2" \
'.agents[] | select(.window == $w) | .[$f] // ""' \
"$STATE_FILE" 2>/dev/null
}
# ---------------------------------------------------------------------------
# wait_for_prompt WINDOW — wait up to 60s for Claude's prompt
# ---------------------------------------------------------------------------
wait_for_prompt() {
local window="$1"
for i in $(seq 1 60); do
local cmd pane
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
if echo "$pane" | grep -q "Enter to confirm"; then
tmux send-keys -t "$window" Down Enter; sleep 2; continue
fi
[[ "$cmd" == "node" ]] && echo "$pane" | grep -q "" && return 0
sleep 1
done
return 1 # timed out
}
# ---------------------------------------------------------------------------
# handle_kick WINDOW STATE — only for idle (crashed) agents, not stuck
# ---------------------------------------------------------------------------
handle_kick() {
local window="$1" state="$2"
[[ "$state" != "idle" ]] && return # stuck agents handled by supervisor
local worktree_path session_id
worktree_path=$(agent_field "$window" "worktree_path")
session_id=$(agent_field "$window" "session_id")
echo "[$(date +%H:%M:%S)] KICK restart $window — agent exited, resuming session"
# Resume the exact session so the agent retains full context — no need to re-send objective
if [ -n "$session_id" ]; then
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --resume '${session_id}' --permission-mode bypassPermissions" Enter
else
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --continue --permission-mode bypassPermissions" Enter
fi
wait_for_prompt "$window" || echo "[$(date +%H:%M:%S)] KICK WARNING $window — timed out waiting for "
}
# ---------------------------------------------------------------------------
# handle_approve WINDOW — auto-approve dialogs that need no judgment
# ---------------------------------------------------------------------------
handle_approve() {
local window="$1"
local pane_tail
pane_tail=$(tmux capture-pane -t "$window" -p 2>/dev/null | tail -3 || echo "")
# Settings error dialog at startup
if echo "$pane_tail" | grep -q "Enter to confirm"; then
echo "[$(date +%H:%M:%S)] APPROVE dialog $window — settings error"
tmux send-keys -t "$window" Down Enter
return
fi
# Numbered-option dialog (e.g. "Do you want to make this edit?")
# is already on option 1 (Yes) — Enter confirms it
if echo "$pane_tail" | grep -qE "\s*1\." || echo "$pane_tail" | grep -q "Esc to cancel"; then
echo "[$(date +%H:%M:%S)] APPROVE edit $window"
tmux send-keys -t "$window" "" Enter
return
fi
# y/n prompt for safe operations
if echo "$pane_tail" | grep -qiE "(^git |^npm |^pnpm |^poetry |^pytest|^docker |^make |^cargo |^pip |^yarn |curl .*(localhost|127\.0\.0\.1))"; then
echo "[$(date +%H:%M:%S)] APPROVE safe $window"
tmux send-keys -t "$window" "y" Enter
return
fi
# Anything else — supervisor handles it, just log
echo "[$(date +%H:%M:%S)] APPROVE skip $window — unknown dialog, supervisor will handle"
}
# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------
echo "[$(date +%H:%M:%S)] run-loop started (mechanical only, poll every ${POLL_INTERVAL}s)"
echo "[$(date +%H:%M:%S)] Supervisor: orchestrating Claude session (not a separate window)"
echo "---"
while true; do
if ! jq -e '.active == true' "$STATE_FILE" >/dev/null 2>&1; then
echo "[$(date +%H:%M:%S)] active=false — exiting."
exit 0
fi
ACTIONS=$("$SCRIPTS_DIR/poll-cycle.sh" 2>/dev/null || echo "[]")
KICKED=0; DONE=0
while IFS= read -r action; do
[ -z "$action" ] && continue
WINDOW=$(echo "$action" | jq -r '.window // ""')
ACTION=$(echo "$action" | jq -r '.action // ""')
STATE=$(echo "$action" | jq -r '.state // ""')
case "$ACTION" in
kick) handle_kick "$WINDOW" "$STATE" || true; KICKED=$(( KICKED + 1 )) ;;
approve) handle_approve "$WINDOW" || true ;;
complete) DONE=$(( DONE + 1 )) ;; # poll-cycle already set state=pending_evaluation; orchestrator handles
esac
done < <(echo "$ACTIONS" | jq -c '.[]' 2>/dev/null || true)
RUNNING=$(jq '[.agents[] | select(.state | test("running|stuck|waiting_approval|idle"))] | length' \
"$STATE_FILE" 2>/dev/null || echo 0)
echo "[$(date +%H:%M:%S)] Poll — ${RUNNING} running ${KICKED} kicked ${DONE} recycled"
sleep "$POLL_INTERVAL"
done

View File

@@ -0,0 +1,122 @@
#!/usr/bin/env bash
# spawn-agent.sh — create tmux window, checkout branch, launch claude, send task
#
# Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]
# SESSION — tmux session name, e.g. autogpt1
# WORKTREE_PATH — absolute path to the git worktree
# SPARE_BRANCH — spare branch being replaced, e.g. spare/6 (saved for recycle)
# NEW_BRANCH — task branch to create, e.g. feat/my-feature
# OBJECTIVE — task description sent to the agent
# PR_NUMBER — (optional) GitHub PR number for completion verification
# STEPS... — (optional) required checkpoint names, e.g. pr-address pr-test
#
# Stdout: SESSION:WINDOW_INDEX (nothing else — callers rely on this)
# Exit non-zero on failure.
set -euo pipefail
if [ $# -lt 5 ]; then
echo "Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]" >&2
exit 1
fi
SESSION="$1"
WORKTREE_PATH="$2"
SPARE_BRANCH="$3"
NEW_BRANCH="$4"
OBJECTIVE="$5"
PR_NUMBER="${6:-}"
STEPS=("${@:7}")
WORKTREE_NAME=$(basename "$WORKTREE_PATH")
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
# Generate a stable session ID so this agent's Claude session can always be resumed:
# claude --resume $SESSION_ID --permission-mode bypassPermissions
SESSION_ID=$(uuidgen 2>/dev/null || python3 -c "import uuid; print(uuid.uuid4())")
# Create (or switch to) the task branch
git -C "$WORKTREE_PATH" checkout -b "$NEW_BRANCH" 2>/dev/null \
|| git -C "$WORKTREE_PATH" checkout "$NEW_BRANCH"
# Open a new named tmux window; capture its numeric index
WIN_IDX=$(tmux new-window -t "$SESSION" -n "$WORKTREE_NAME" -P -F '#{window_index}')
WINDOW="${SESSION}:${WIN_IDX}"
# Append the initial agent record to the state file so subsequent jq updates find it.
# This must happen before the pr_number/steps update below.
if [ -f "$STATE_FILE" ]; then
NOW=$(date +%s)
jq --arg window "$WINDOW" \
--arg worktree "$WORKTREE_NAME" \
--arg worktree_path "$WORKTREE_PATH" \
--arg spare_branch "$SPARE_BRANCH" \
--arg branch "$NEW_BRANCH" \
--arg objective "$OBJECTIVE" \
--arg session_id "$SESSION_ID" \
--argjson now "$NOW" \
'.agents += [{
"window": $window,
"worktree": $worktree,
"worktree_path": $worktree_path,
"spare_branch": $spare_branch,
"branch": $branch,
"objective": $objective,
"session_id": $session_id,
"state": "running",
"checkpoints": [],
"last_output_hash": "",
"last_seen_at": $now,
"spawned_at": $now,
"idle_since": 0,
"revision_count": 0,
"last_rebriefed_at": 0
}]' "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
fi
# Store pr_number + steps in state file if provided (enables verify-complete.sh).
# The agent record was appended above so the jq select now finds it.
if [ -n "$PR_NUMBER" ] && [ -f "$STATE_FILE" ]; then
if [ "${#STEPS[@]}" -gt 0 ]; then
STEPS_JSON=$(printf '%s\n' "${STEPS[@]}" | jq -R . | jq -s .)
else
STEPS_JSON='[]'
fi
jq --arg w "$WINDOW" --arg pr "$PR_NUMBER" --argjson steps "$STEPS_JSON" \
'.agents |= map(if .window == $w then . + {pr_number: $pr, steps: $steps, checkpoints: []} else . end)' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
fi
# Launch claude with a stable session ID so it can always be resumed after a crash:
# claude --resume SESSION_ID --permission-mode bypassPermissions
tmux send-keys -t "$WINDOW" "cd '${WORKTREE_PATH}' && claude --permission-mode bypassPermissions --session-id '${SESSION_ID}'" Enter
# Wait up to 60s for claude to be fully interactive:
# both pane_current_command == 'node' AND the '❯' prompt is visible.
PROMPT_FOUND=false
for i in $(seq 1 60); do
CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "")
PANE=$(tmux capture-pane -t "$WINDOW" -p 2>/dev/null || echo "")
if echo "$PANE" | grep -q "Enter to confirm"; then
tmux send-keys -t "$WINDOW" Down Enter
sleep 2
continue
fi
if [[ "$CMD" == "node" ]] && echo "$PANE" | grep -q ""; then
PROMPT_FOUND=true
break
fi
sleep 1
done
if ! $PROMPT_FOUND; then
echo "[spawn-agent] WARNING: timed out waiting for prompt on $WINDOW — sending objective anyway" >&2
fi
# Send the task. Split text and Enter — if combined, Enter can fire before the string
# is fully buffered, leaving the message stuck as "[Pasted text +N lines]" unsent.
tmux send-keys -t "$WINDOW" "${OBJECTIVE} Output each completed step as CHECKPOINT:<step-name>. When ALL steps are done, output ORCHESTRATOR:DONE on its own line."
sleep 0.3
tmux send-keys -t "$WINDOW" Enter
# Only output the window address — nothing else (callers parse this)
echo "$WINDOW"

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# status.sh — print orchestrator status: state file summary + live tmux pane commands
#
# Usage: status.sh
# Reads: ~/.claude/orchestrator-state.json
set -euo pipefail
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
if [ ! -f "$STATE_FILE" ] || ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
echo "No orchestrator state found at $STATE_FILE"
exit 0
fi
# Header: active status, session, thresholds, last poll
jq -r '
"=== Orchestrator [\(if .active then "RUNNING" else "STOPPED" end)] ===",
"Session: \(.tmux_session // "unknown") | Idle threshold: \(.idle_threshold_seconds // 300)s",
"Last poll: \(if (.last_poll_at // 0) == 0 then "never" else (.last_poll_at | strftime("%H:%M:%S")) end)",
""
' "$STATE_FILE"
# Each agent: state, window, worktree/branch, truncated objective
AGENT_COUNT=$(jq '.agents | length' "$STATE_FILE")
if [ "$AGENT_COUNT" -eq 0 ]; then
echo " (no agents registered)"
else
jq -r '
.agents[] |
" [\(.state | ascii_upcase)] \(.window) \(.worktree)/\(.branch)",
" \(.objective // "" | .[0:70])"
' "$STATE_FILE"
fi
echo ""
# Live pane_current_command for non-done agents
while IFS= read -r WINDOW; do
[ -z "$WINDOW" ] && continue
CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "unreachable")
echo " $WINDOW live: $CMD"
done < <(jq -r '.agents[] | select(.state != "done") | .window' "$STATE_FILE" 2>/dev/null || true)

View File

@@ -0,0 +1,180 @@
#!/usr/bin/env bash
# verify-complete.sh — verify a PR task is truly done before marking the agent done
#
# Check order matters:
# 1. Checkpoints — did the agent do all required steps?
# 2. CI complete — no pending (bots post comments AFTER their check runs, must wait)
# 3. CI passing — no failures (agent must fix before done)
# 4. spawned_at — a new CI run was triggered after agent spawned (proves real work)
# 5. Unresolved threads — checked AFTER CI so bot-posted comments are included
# 6. CHANGES_REQUESTED — checked AFTER CI so bot reviews are included
#
# Usage: verify-complete.sh WINDOW
# Exit 0 = verified complete; exit 1 = not complete (stderr has reason)
set -euo pipefail
WINDOW="$1"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
PR_NUMBER=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .pr_number // ""' "$STATE_FILE" 2>/dev/null)
STEPS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .steps // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
CHECKPOINTS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .checkpoints // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
WORKTREE_PATH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .worktree_path // ""' "$STATE_FILE" 2>/dev/null)
BRANCH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .branch // ""' "$STATE_FILE" 2>/dev/null)
SPAWNED_AT=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .spawned_at // "0"' "$STATE_FILE" 2>/dev/null || echo "0")
# No PR number = cannot verify
if [ -z "$PR_NUMBER" ]; then
echo "NOT COMPLETE: no pr_number in state — set pr_number or mark done manually" >&2
exit 1
fi
# --- Check 1: all required steps are checkpointed ---
MISSING=""
while IFS= read -r step; do
[ -z "$step" ] && continue
if ! echo "$CHECKPOINTS" | grep -qFx "$step"; then
MISSING="$MISSING $step"
fi
done <<< "$STEPS"
if [ -n "$MISSING" ]; then
echo "NOT COMPLETE: missing checkpoints:$MISSING on PR #$PR_NUMBER" >&2
exit 1
fi
# Resolve repo for all GitHub checks below
REPO=$(jq -r '.repo // ""' "$STATE_FILE" 2>/dev/null || echo "")
if [ -z "$REPO" ] && [ -n "$WORKTREE_PATH" ] && [ -d "$WORKTREE_PATH" ]; then
REPO=$(git -C "$WORKTREE_PATH" remote get-url origin 2>/dev/null \
| sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
fi
if [ -z "$REPO" ]; then
echo "Warning: cannot resolve repo — skipping CI/thread checks" >&2
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓ (CI/thread checks skipped — no repo)"
exit 0
fi
CI_BUCKETS=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket 2>/dev/null || echo "[]")
# --- Check 2: CI fully complete — no pending checks ---
# Pending checks MUST finish before we check threads/reviews:
# bots (Seer, Check PR Status, etc.) post comments and CHANGES_REQUESTED AFTER their CI check runs.
PENDING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "pending")] | length' 2>/dev/null || echo "0")
if [ "$PENDING" -gt 0 ]; then
PENDING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
| jq -r '[.[] | select(.bucket == "pending") | .name] | join(", ")' 2>/dev/null || echo "unknown")
echo "NOT COMPLETE: $PENDING CI checks still pending on PR #$PR_NUMBER ($PENDING_NAMES)" >&2
exit 1
fi
# --- Check 3: CI passing — no failures ---
FAILING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "fail")] | length' 2>/dev/null || echo "0")
if [ "$FAILING" -gt 0 ]; then
FAILING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
| jq -r '[.[] | select(.bucket == "fail") | .name] | join(", ")' 2>/dev/null || echo "unknown")
echo "NOT COMPLETE: $FAILING failing CI checks on PR #$PR_NUMBER ($FAILING_NAMES)" >&2
exit 1
fi
# --- Check 4: a new CI run was triggered AFTER the agent spawned ---
if [ -n "$BRANCH" ] && [ "${SPAWNED_AT:-0}" -gt 0 ]; then
LATEST_RUN_AT=$(gh run list --repo "$REPO" --branch "$BRANCH" \
--json createdAt --limit 1 2>/dev/null | jq -r '.[0].createdAt // ""')
if [ -n "$LATEST_RUN_AT" ]; then
if date --version >/dev/null 2>&1; then
LATEST_RUN_EPOCH=$(date -d "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
else
LATEST_RUN_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
fi
if [ "$LATEST_RUN_EPOCH" -le "$SPAWNED_AT" ]; then
echo "NOT COMPLETE: latest CI run on $BRANCH predates agent spawn — agent may not have pushed yet" >&2
exit 1
fi
fi
fi
OWNER=$(echo "$REPO" | cut -d/ -f1)
REPONAME=$(echo "$REPO" | cut -d/ -f2)
# --- Check 5: no unresolved review threads (checked AFTER CI — bots post after their check) ---
UNRESOLVED=$(gh api graphql -f query="
{ repository(owner: \"${OWNER}\", name: \"${REPONAME}\") {
pullRequest(number: ${PR_NUMBER}) {
reviewThreads(first: 50) { nodes { isResolved } }
}
}
}
" --jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)] | length' 2>/dev/null || echo "0")
if [ "$UNRESOLVED" -gt 0 ]; then
echo "NOT COMPLETE: $UNRESOLVED unresolved review threads on PR #$PR_NUMBER" >&2
exit 1
fi
# --- Check 6: no CHANGES_REQUESTED (checked AFTER CI — bots post reviews after their check) ---
# A CHANGES_REQUESTED review is stale if the latest commit was pushed AFTER the review was submitted.
# Stale reviews (pre-dating the fixing commits) should not block verification.
#
# Fetch commits and latestReviews in a single call and fail closed — if gh fails,
# treat that as NOT COMPLETE rather than silently passing.
# Use latestReviews (not reviews) so each reviewer's latest state is used — superseded
# CHANGES_REQUESTED entries are automatically excluded when the reviewer later approved.
# Note: we intentionally use committedDate (not PR updatedAt) because updatedAt changes on any
# PR activity (bot comments, label changes) which would create false negatives.
PR_REVIEW_METADATA=$(gh pr view "$PR_NUMBER" --repo "$REPO" \
--json commits,latestReviews 2>/dev/null) || {
echo "NOT COMPLETE: unable to fetch PR review metadata for PR #$PR_NUMBER" >&2
exit 1
}
LATEST_COMMIT_DATE=$(jq -r '.commits[-1].committedDate // ""' <<< "$PR_REVIEW_METADATA")
CHANGES_REQUESTED_REVIEWS=$(jq '[.latestReviews[]? | select(.state == "CHANGES_REQUESTED")]' <<< "$PR_REVIEW_METADATA")
BLOCKING_CHANGES_REQUESTED=0
BLOCKING_REQUESTERS=""
if [ -n "$LATEST_COMMIT_DATE" ] && [ "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length)" -gt 0 ]; then
if date --version >/dev/null 2>&1; then
LATEST_COMMIT_EPOCH=$(date -d "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
else
LATEST_COMMIT_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
fi
while IFS= read -r review; do
[ -z "$review" ] && continue
REVIEW_DATE=$(echo "$review" | jq -r '.submittedAt // ""')
REVIEWER=$(echo "$review" | jq -r '.author.login // "unknown"')
if [ -z "$REVIEW_DATE" ]; then
# No submission date — treat as fresh (conservative: blocks verification)
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
else
if date --version >/dev/null 2>&1; then
REVIEW_EPOCH=$(date -d "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
else
REVIEW_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
fi
if [ "$REVIEW_EPOCH" -gt "$LATEST_COMMIT_EPOCH" ]; then
# Review was submitted AFTER latest commit — still fresh, blocks verification
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
fi
# Review submitted BEFORE latest commit — stale, skip
fi
done <<< "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -c '.[]')"
else
  # No commit date (fail closed: every CHANGES_REQUESTED blocks) or no CHANGES_REQUESTED at all
BLOCKING_CHANGES_REQUESTED=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length 2>/dev/null || echo "0")
BLOCKING_REQUESTERS=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -r '[.[].author.login] | join(", ")' 2>/dev/null || echo "unknown")
fi
if [ "$BLOCKING_CHANGES_REQUESTED" -gt 0 ]; then
echo "NOT COMPLETE: CHANGES_REQUESTED (after latest commit) from ${BLOCKING_REQUESTERS} on PR #$PR_NUMBER" >&2
exit 1
fi
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓, CI complete + green, 0 unresolved threads, no CHANGES_REQUESTED"
exit 0
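The staleness rule in Check 6 can be exercised in isolation. A minimal sketch, with made-up timestamps and GNU `date` assumed (the script above also carries the BSD `date -j` fallback):

```shell
# Standalone sketch of Check 6's staleness comparison (hypothetical timestamps).
REVIEW_AT="2024-04-01T10:00:00Z"   # when CHANGES_REQUESTED was submitted
COMMIT_AT="2024-04-02T09:00:00Z"   # when the latest (fixing) commit landed
REVIEW_EPOCH=$(date -d "$REVIEW_AT" "+%s")
COMMIT_EPOCH=$(date -d "$COMMIT_AT" "+%s")
if [ "$REVIEW_EPOCH" -gt "$COMMIT_EPOCH" ]; then
  VERDICT="fresh"   # review postdates the latest commit: blocks verification
else
  VERDICT="stale"   # review predates the fixing commit: skipped
fi
echo "$VERDICT"
```

Here the review predates the commit, so the sketch prints `stale` and the review would not block.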

View File

@@ -90,10 +90,12 @@ Address comments **one at a time**: fix → commit → push → inline reply →
2. Commit and push the fix
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
Use a **markdown commit link** so GitHub renders it as a clickable reference. Get the full SHA with `git rev-parse HEAD` after committing:
| Comment type | How to reply |
|---|---|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
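As a sketch of assembling the reply body itself (the SHA below is a placeholder, not a real commit; `FULL_SHA` would come from `git rev-parse HEAD` after committing the fix):

```shell
# Hypothetical sketch: build the markdown commit-link reply body.
FULL_SHA="0123456789abcdef0123456789abcdef01234567"
SHORT_SHA="${FULL_SHA:0:7}"   # GitHub-style short SHA used as the link text
BODY="🤖 Fixed in [${SHORT_SHA}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): <description>"
echo "$BODY"
```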
## Codecov coverage

View File

@@ -530,19 +530,9 @@ After showing all screenshots, output a **detailed** summary table:
# but Homebrew bash is 5.x; Linux typically has bash 5.x). If running on Bash <4, use a
# plain variable with a lookup function instead.
declare -A SCREENSHOT_EXPLANATIONS=(
# Each explanation MUST answer three things:
# 1. FLOW: Which test scenario / user journey is this part of?
# 2. STEPS: What exact actions were taken to reach this state?
# 3. EVIDENCE: What does this screenshot prove (pass/fail/data)?
#
# Good example:
# ["03-cost-log-after-run.png"]="Flow: LLM block cost tracking. Steps: Logged in as tester@gmail.com → ran 'Cost Test Agent' → waited for COMPLETED status. Evidence: PlatformCostLog table shows 1 new row with cost_microdollars=1234 and correct user_id."
#
# Bad example (too vague — never do this):
# ["03-cost-log.png"]="Shows the cost log table."
["01-login-page.png"]="Flow: Login flow. Steps: Opened /login. Evidence: Login page renders with email/password fields and SSO options visible."
["02-builder-with-block.png"]="Flow: Block execution. Steps: Logged in → /build → added LLM block. Evidence: Builder canvas shows block connected to trigger, ready to run."
# ... one entry per screenshot using the flow/steps/evidence format above
["01-login-page.png"]="Shows the login page loaded successfully with SSO options visible."
["02-builder-with-block.png"]="The builder canvas displays the newly added block connected to the trigger."
# ... one entry per screenshot, using the same explanations you showed the user above
)
TEST_RESULTS_TABLE="| 1 | Login flow | PASS | N/A | 01-login-before.png, 02-login-after.png |
@@ -557,8 +547,7 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
> **CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.**
> Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
@@ -595,11 +584,11 @@ for img in "${SCREENSHOT_FILES[@]}"; do
done
TREE_JSON+=']'
# Step 2: Create tree, commit (with parent), and branch ref
# Step 2: Create tree, commit, and branch ref
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
# Resolve existing branch tip as parent (avoids orphan commits on repeat runs)
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || true)
# Resolve parent commit so screenshots are chained, not orphan root commits
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || echo "")
if [ -n "$PARENT_SHA" ]; then
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
@@ -607,7 +596,6 @@ if [ -n "$PARENT_SHA" ]; then
-f "parents[]=$PARENT_SHA" \
--jq '.sha')
else
# First commit on this branch — no parent
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
@@ -618,7 +606,7 @@ gh api "repos/${REPO}/git/refs" \
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
-f sha="$COMMIT_SHA" 2>/dev/null \
|| gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
-X PATCH -f sha="$COMMIT_SHA" -f force=true
-X PATCH -f sha="$COMMIT_SHA" -F force=true
```
Then post the comment with **inline images AND explanations for each screenshot**:
@@ -682,122 +670,122 @@ ${IMAGE_MARKDOWN}
${FAILED_SECTION}
INNEREOF
POSTED_BODY=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE" --jq '.body')
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
rm -f "$COMMENT_FILE"
# Verify the posted comment contains inline images — exit 1 if none found
# --paginate emits one JSON array per page; slurp (-s) and flatten with `add` so
# .[-1] is the true last comment across all pages, not the last of each page.
LAST_COMMENT=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" --paginate 2>/dev/null | jq -rs 'add | .[-1].body // ""')
if ! echo "$LAST_COMMENT" | grep -q '!\['; then
echo "ERROR: Posted comment contains no inline images (![). Bare directory links are not acceptable." >&2
exit 1
fi
echo "✓ Inline images verified in posted comment"
```
**The PR comment MUST include:**
1. A summary table of all scenarios with PASS/FAIL and before/after API evidence
2. Every successfully uploaded screenshot rendered inline; any failed uploads listed with manual attachment instructions
3. A structured explanation below each screenshot covering: **Flow** (which scenario), **Steps** (exact actions taken to reach this state), **Evidence** (what this proves — pass/fail/data values). A bare "shows the page" caption is not acceptable.
3. A 1-2 sentence explanation below each screenshot describing what it proves
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
**Verify inline rendering after posting — this is required, not optional:**
## Step 8: Evaluate and post a formal PR review
After the test comment is posted, evaluate whether the run was thorough enough to make a merge decision, then post a formal GitHub review (approve or request changes). **This step is mandatory — every test run MUST end with a formal review decision.**
### Evaluation criteria
Re-read the PR description:
```bash
# 1. Confirm the posted comment body contains inline image markdown syntax
if ! echo "$POSTED_BODY" | grep -q '!\['; then
echo "❌ FAIL: No inline image tags in posted comment body. Re-check IMAGE_MARKDOWN and re-post."
exit 1
fi
# 2. Verify at least one raw URL actually resolves (catches wrong branch name, wrong path, etc.)
FIRST_IMG_URL=$(echo "$POSTED_BODY" | grep -o 'https://raw.githubusercontent.com[^)]*' | head -1)
if [ -n "$FIRST_IMG_URL" ]; then
HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" --max-time 10 "$FIRST_IMG_URL")
if [ "$HTTP_STATUS" = "200" ]; then
echo "✅ Inline images confirmed and raw URL resolves (HTTP 200)"
else
echo "❌ FAIL: Raw image URL returned HTTP $HTTP_STATUS — images will not render inline."
echo " URL: $FIRST_IMG_URL"
echo " Check branch name, path, and that the push succeeded."
exit 1
fi
else
echo "⚠️ Could not extract a raw URL from the comment — verify manually."
fi
gh pr view "$PR_NUMBER" --json body --jq '.body' --repo "$REPO"
```
## Step 8: Evaluate test completeness and post a GitHub review
Score the run against each criterion:
After posting the PR comment, evaluate whether the test run actually covered everything it needed to. This is NOT a rubber-stamp — be critical. Then post a formal GitHub review so the PR author and reviewers can see the verdict.
| Criterion | Pass condition |
|-----------|---------------|
| **Coverage** | Every feature/change described in the PR has at least one test scenario |
| **All scenarios pass** | No FAIL rows in the results table |
| **Negative tests** | At least one failure-path test per feature (invalid input, unauthorized, edge case) |
| **Before/after evidence** | Every state-changing API call has before/after values logged |
| **Screenshots are meaningful** | Screenshots show the actual state change, not just a loading spinner or blank page |
| **No regressions** | Existing core flows (login, agent create/run) still work |
### 8a. Evaluate against the test plan
### Decision logic
Re-read `$RESULTS_DIR/test-plan.md` (written in Step 2) and `$RESULTS_DIR/test-report.md` (written in Step 5). For each scenario in the plan, answer:
```
ALL criteria pass → APPROVE
Any scenario FAIL or missing PR feature → REQUEST_CHANGES (list gaps)
Evidence weak (no before/after, vague shots) → REQUEST_CHANGES (list what's missing)
```
> **Note:** `test-report.md` is written in Step 5. If it doesn't exist, write it before proceeding here — see the Step 5 template. Do not skip evaluation because the file is missing; create it from your notes instead.
| Question | Pass criteria |
|----------|--------------|
| Was it tested? | Explicit steps were executed, not just described |
| Is there screenshot evidence? | At least one before/after screenshot per scenario |
| Did the core feature work correctly? | Expected state matches actual state |
| Were negative cases tested? | At least one failure/rejection case per feature |
| Was DB/API state verified (not just UI)? | Raw API response or DB query confirms state change |
Build a verdict:
- **APPROVE** — every scenario tested, evidence present, no bugs found or all bugs are minor/known
- **REQUEST_CHANGES** — one or more: untested scenarios, missing evidence, bugs found, data not verified
### 8b. Post the GitHub review
### Post the review
```bash
EVAL_FILE=$(mktemp)
REVIEW_FILE=$(mktemp)
# === STEP A: Write header ===
cat > "$EVAL_FILE" << 'ENDEVAL'
## 🧪 Test Evaluation
# Count results
PASS_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "PASS" || true)
FAIL_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "FAIL" || true)
TOTAL=$(( PASS_COUNT + FAIL_COUNT ))
### Coverage checklist
ENDEVAL
# List any coverage gaps found during evaluation (populate this array as you assess)
# e.g. COVERAGE_GAPS=("PR claims to add X but no test covers it")
COVERAGE_GAPS=()
```
# === STEP B: Append ONE line per scenario — do this BEFORE calculating verdict ===
# Format: "- ✅ **Scenario N name**: <what was done and verified>"
# or "- ❌ **Scenario N name**: <what is missing or broken>"
# Examples:
# echo "- ✅ **Scenario 1 Login flow**: tested, screenshot evidence present, auth token verified via API" >> "$EVAL_FILE"
# echo "- ❌ **Scenario 3 Cost logging**: NOT verified in DB — UI showed entry but raw SQL query was skipped" >> "$EVAL_FILE"
#
# !!! IMPORTANT: append ALL scenario lines here before proceeding to STEP C !!!
**If APPROVING** — all criteria met, zero failures, full coverage:
# === STEP C: Derive verdict from the checklist — runs AFTER all lines are appended ===
FAIL_COUNT=$(grep -c "^- ❌" "$EVAL_FILE" || true)
if [ "$FAIL_COUNT" -eq 0 ]; then
VERDICT="APPROVE"
else
VERDICT="REQUEST_CHANGES"
fi
```bash
cat > "$REVIEW_FILE" <<REVIEWEOF
## E2E Test Evaluation — APPROVED
# === STEP D: Append verdict section ===
cat >> "$EVAL_FILE" << ENDVERDICT
**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed.
### Verdict
ENDVERDICT
**Coverage:** All features described in the PR were exercised.
if [ "$VERDICT" = "APPROVE" ]; then
echo "✅ All scenarios covered with evidence. No blocking issues found." >> "$EVAL_FILE"
else
echo "$FAIL_COUNT scenario(s) incomplete or have confirmed bugs. See ❌ items above." >> "$EVAL_FILE"
echo "" >> "$EVAL_FILE"
echo "**Required before merge:** address each ❌ item above." >> "$EVAL_FILE"
fi
**Evidence:** Before/after API values logged for all state-changing operations; screenshots show meaningful state transitions.
# === STEP E: Post the review ===
gh api "repos/${REPO}/pulls/$PR_NUMBER/reviews" \
--method POST \
-f body="$(cat "$EVAL_FILE")" \
-f event="$VERDICT"
**Negative tests:** Failure paths tested for each feature.
rm -f "$EVAL_FILE"
No regressions observed on core flows.
REVIEWEOF
gh pr review "$PR_NUMBER" --repo "$REPO" --approve --body "$(cat "$REVIEW_FILE")"
echo "✅ PR approved"
```
**If REQUESTING CHANGES** — any failure, coverage gap, or missing evidence:
```bash
FAIL_LIST=$(echo "$TEST_RESULTS_TABLE" | grep "FAIL" | awk -F'|' '{print "- Scenario" $2 "failed"}' || true)
cat > "$REVIEW_FILE" <<REVIEWEOF
## E2E Test Evaluation — Changes Requested
**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed, ${FAIL_COUNT} failed.
### Required before merge
${FAIL_LIST}
$(for gap in "${COVERAGE_GAPS[@]}"; do echo "- $gap"; done)
Please fix the above and re-run the E2E tests.
REVIEWEOF
gh pr review "$PR_NUMBER" --repo "$REPO" --request-changes --body "$(cat "$REVIEW_FILE")"
echo "❌ Changes requested"
```
```bash
rm -f "$REVIEW_FILE"
```
**Rules:**
- Never auto-approve without checking every scenario in the test plan
- `REQUEST_CHANGES` if ANY scenario is untested, lacks DB/API evidence, or has a confirmed bug
- The evaluation body must list every scenario explicitly (✅ or ❌) — not just the failures
- If you find new bugs during evaluation, add them to the request-changes body and (if `--fix` flag is set) fix them before posting
- In `--fix` mode, fix all failures before posting the review — the review reflects the final state after fixes
- Never approve if any scenario failed, even if it seems like a flake — rerun that scenario first
- Never request changes for issues already fixed in this run
## Fix mode (--fix flag)

View File

@@ -1,98 +0,0 @@
import logging
from datetime import datetime
from autogpt_libs.auth import get_user_id, requires_admin_user
from cachetools import TTLCache
from fastapi import APIRouter, Query, Security
from pydantic import BaseModel
from backend.data.platform_cost import (
CostLogRow,
PlatformCostDashboard,
get_platform_cost_dashboard,
get_platform_cost_logs,
)
from backend.util.models import Pagination
logger = logging.getLogger(__name__)
# Cache dashboard results for 30 seconds per unique filter combination.
# The table is append-only so stale reads are acceptable for analytics.
_DASHBOARD_CACHE_TTL = 30
_dashboard_cache: TTLCache[tuple, PlatformCostDashboard] = TTLCache(
maxsize=256, ttl=_DASHBOARD_CACHE_TTL
)
router = APIRouter(
prefix="/platform-costs",
tags=["platform-cost", "admin"],
dependencies=[Security(requires_admin_user)],
)
class PlatformCostLogsResponse(BaseModel):
logs: list[CostLogRow]
pagination: Pagination
@router.get(
"/dashboard",
response_model=PlatformCostDashboard,
summary="Get Platform Cost Dashboard",
)
async def get_cost_dashboard(
admin_user_id: str = Security(get_user_id),
start: datetime | None = Query(None),
end: datetime | None = Query(None),
provider: str | None = Query(None),
user_id: str | None = Query(None),
):
logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
cache_key = (start, end, provider, user_id)
cached = _dashboard_cache.get(cache_key)
if cached is not None:
return cached
result = await get_platform_cost_dashboard(
start=start,
end=end,
provider=provider,
user_id=user_id,
)
_dashboard_cache[cache_key] = result
return result
@router.get(
"/logs",
response_model=PlatformCostLogsResponse,
summary="Get Platform Cost Logs",
)
async def get_cost_logs(
admin_user_id: str = Security(get_user_id),
start: datetime | None = Query(None),
end: datetime | None = Query(None),
provider: str | None = Query(None),
user_id: str | None = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=200),
):
logger.info("Admin %s fetching platform cost logs", admin_user_id)
logs, total = await get_platform_cost_logs(
start=start,
end=end,
provider=provider,
user_id=user_id,
page=page,
page_size=page_size,
)
total_pages = (total + page_size - 1) // page_size
return PlatformCostLogsResponse(
logs=logs,
pagination=Pagination(
total_items=total,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
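The `total_pages` computation in `get_cost_logs` uses integer ceiling division so a partial final page still counts as a page. A quick standalone check of that arithmetic:

```python
# Ceiling division as used for total_pages: (total + page_size - 1) // page_size
def total_pages(total_items: int, page_size: int) -> int:
    return (total_items + page_size - 1) // page_size

assert total_pages(0, 50) == 0      # no rows, no pages
assert total_pages(50, 50) == 1     # exact fit
assert total_pages(101, 50) == 3    # 2 full pages plus 1 partial page
```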

View File

@@ -1,192 +0,0 @@
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.platform_cost import PlatformCostDashboard
from . import platform_cost_routes
from .platform_cost_routes import router as platform_cost_router
app = fastapi.FastAPI()
app.include_router(platform_cost_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
# Clear TTL cache so each test starts cold.
platform_cost_routes._dashboard_cache.clear()
yield
app.dependency_overrides.clear()
def test_get_dashboard_success(
mocker: pytest_mock.MockerFixture,
) -> None:
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=0,
total_requests=0,
total_users=0,
)
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
AsyncMock(return_value=real_dashboard),
)
response = client.get("/platform-costs/dashboard")
assert response.status_code == 200
data = response.json()
assert "by_provider" in data
assert "by_user" in data
assert data["total_cost_microdollars"] == 0
def test_get_logs_success(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
AsyncMock(return_value=([], 0)),
)
response = client.get("/platform-costs/logs")
assert response.status_code == 200
data = response.json()
assert data["logs"] == []
assert data["pagination"]["total_items"] == 0
def test_get_dashboard_with_filters(
mocker: pytest_mock.MockerFixture,
) -> None:
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=0,
total_requests=0,
total_users=0,
)
mock_dashboard = AsyncMock(return_value=real_dashboard)
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
mock_dashboard,
)
response = client.get(
"/platform-costs/dashboard",
params={
"start": "2026-01-01T00:00:00",
"end": "2026-04-01T00:00:00",
"provider": "openai",
"user_id": "test-user-123",
},
)
assert response.status_code == 200
mock_dashboard.assert_called_once()
call_kwargs = mock_dashboard.call_args.kwargs
assert call_kwargs["provider"] == "openai"
assert call_kwargs["user_id"] == "test-user-123"
assert call_kwargs["start"] is not None
assert call_kwargs["end"] is not None
def test_get_logs_with_pagination(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
AsyncMock(return_value=([], 0)),
)
response = client.get(
"/platform-costs/logs",
params={"page": 2, "page_size": 25, "provider": "anthropic"},
)
assert response.status_code == 200
data = response.json()
assert data["pagination"]["current_page"] == 2
assert data["pagination"]["page_size"] == 25
def test_get_dashboard_requires_admin() -> None:
import fastapi
from fastapi import HTTPException
def reject_jwt(request: fastapi.Request):
raise HTTPException(status_code=401, detail="Not authenticated")
app.dependency_overrides[get_jwt_payload] = reject_jwt
try:
response = client.get("/platform-costs/dashboard")
assert response.status_code == 401
response = client.get("/platform-costs/logs")
assert response.status_code == 401
finally:
app.dependency_overrides.clear()
def test_get_dashboard_rejects_non_admin(mock_jwt_user, mock_jwt_admin) -> None:
"""Non-admin JWT must be rejected with 403 by requires_admin_user."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
try:
response = client.get("/platform-costs/dashboard")
assert response.status_code == 403
response = client.get("/platform-costs/logs")
assert response.status_code == 403
finally:
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
def test_get_logs_invalid_page_size_too_large() -> None:
"""page_size > 200 must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page_size": 201})
assert response.status_code == 422
def test_get_logs_invalid_page_size_zero() -> None:
"""page_size = 0 (below ge=1) must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page_size": 0})
assert response.status_code == 422
def test_get_logs_invalid_page_negative() -> None:
"""page < 1 must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page": 0})
assert response.status_code == 422
def test_get_dashboard_invalid_date_format() -> None:
"""Malformed start date must be rejected with 422."""
response = client.get("/platform-costs/dashboard", params={"start": "not-a-date"})
assert response.status_code == 422
def test_get_dashboard_cache_hit(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Second identical request returns cached result without calling the DB again."""
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=42,
total_requests=1,
total_users=1,
)
mock_fn = mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
AsyncMock(return_value=real_dashboard),
)
client.get("/platform-costs/dashboard")
client.get("/platform-costs/dashboard")
mock_fn.assert_awaited_once() # second request hit the cache
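The cache-hit behavior asserted in `test_get_dashboard_cache_hit` rests on a keyed TTL cache. A stdlib-only sketch of the pattern (`TinyTTLCache` is illustrative; the real route uses `cachetools.TTLCache` with the filter tuple as the key):

```python
import time

class TinyTTLCache:
    """Minimal TTL cache sketch: key -> (expires_at, value), lazy eviction."""

    def __init__(self, ttl: float):
        self.ttl = ttl
        self._store: dict = {}

    def get(self, key):
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if time.monotonic() > expires_at:
            del self._store[key]  # expired: evict lazily on read
            return None
        return value

    def __setitem__(self, key, value):
        self._store[key] = (time.monotonic() + self.ttl, value)

cache = TinyTTLCache(ttl=30)
key = ("2026-01-01", None, "openai", None)  # one unique filter combination
cache[key] = {"total": 42}
assert cache.get(key) == {"total": 42}   # second lookup served from cache
assert cache.get(("other",)) is None     # different filters miss the cache
```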

View File

@@ -18,7 +18,6 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.platform_cost_routes
import backend.api.features.admin.rate_limit_admin_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
@@ -330,11 +329,6 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/copilot",
)
app.include_router(
backend.api.features.admin.platform_cost_routes.router,
tags=["v2", "admin"],
prefix="/api/admin",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],

View File

@@ -17,7 +17,7 @@ from backend.blocks.apollo.models import (
PrimaryPhone,
SearchOrganizationsRequest,
)
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
from backend.data.model import CredentialsField, SchemaField
class SearchOrganizationsBlock(Block):
@@ -218,11 +218,6 @@ To find IDs, identify the values for organization_id when you call this endpoint
) -> BlockOutput:
query = SearchOrganizationsRequest(**input_data.model_dump())
organizations = await self.search_organizations(query, credentials)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(organizations)), provider_cost_type="items"
)
)
for organization in organizations:
yield "organization", organization
yield "organizations", organizations

View File

@@ -21,7 +21,7 @@ from backend.blocks.apollo.models import (
SearchPeopleRequest,
SenorityLevels,
)
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
from backend.data.model import CredentialsField, SchemaField
class SearchPeopleBlock(Block):
@@ -366,9 +366,4 @@ class SearchPeopleBlock(Block):
*(enrich_or_fallback(person) for person in people)
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(people)), provider_cost_type="items"
)
)
yield "people", people

View File

@@ -1,712 +0,0 @@
"""Unit tests for merge_stats cost tracking in individual blocks.
Covers the exa code_context, exa contents, and apollo organization blocks
to verify provider cost is correctly extracted and reported.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, NodeExecutionStats
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
TEST_EXA_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="exa",
api_key=SecretStr("mock-exa-api-key"),
title="Mock Exa API key",
expires_at=None,
)
TEST_EXA_CREDENTIALS_INPUT = {
"provider": TEST_EXA_CREDENTIALS.provider,
"id": TEST_EXA_CREDENTIALS.id,
"type": TEST_EXA_CREDENTIALS.type,
"title": TEST_EXA_CREDENTIALS.title,
}
# ---------------------------------------------------------------------------
# ExaCodeContextBlock — cost_dollars is a string like "0.005"
# ---------------------------------------------------------------------------
class TestExaCodeContextBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_float_cost(self):
"""float(cost_dollars) parsed from API string and passed to merge_stats."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-1",
"query": "how to use hooks",
"response": "Here are some examples...",
"resultsCount": 3,
"costDollars": "0.005",
"searchTime": 1.2,
"outputTokens": 100,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="how to use hooks",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
results = []
async for output in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
results.append(output)
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(0.005)
@pytest.mark.asyncio
async def test_invalid_cost_dollars_does_not_raise(self):
"""When cost_dollars cannot be parsed as float, merge_stats is not called."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-2",
"query": "query",
"response": "response",
"resultsCount": 0,
"costDollars": "N/A",
"searchTime": 0.5,
"outputTokens": 0,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
merge_calls: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: merge_calls.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="query",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert merge_calls == []
@pytest.mark.asyncio
async def test_zero_cost_is_tracked(self):
"""A zero cost_dollars string '0.0' should still be recorded."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-3",
"query": "query",
"response": "...",
"resultsCount": 1,
"costDollars": "0.0",
"searchTime": 0.1,
"outputTokens": 10,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="query",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
# ---------------------------------------------------------------------------
# ExaContentsBlock — response.cost_dollars.total (CostDollars model)
# ---------------------------------------------------------------------------
class TestExaContentsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_cost_dollars_total(self):
"""provider_cost equals response.cost_dollars.total when present."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
cost_dollars = CostDollars(total=0.012)
mock_response = MagicMock()
mock_response.results = []
mock_response.context = None
mock_response.statuses = None
mock_response.cost_dollars = cost_dollars
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.contents.AsyncExa",
return_value=MagicMock(
get_contents=AsyncMock(return_value=mock_response)
),
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaContentsBlock.Input(
urls=["https://example.com"],
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(0.012)
@pytest.mark.asyncio
async def test_no_merge_stats_when_cost_dollars_absent(self):
"""When response.cost_dollars is None, merge_stats is not called."""
from backend.blocks.exa.contents import ExaContentsBlock
block = ExaContentsBlock()
mock_response = MagicMock()
mock_response.results = []
mock_response.context = None
mock_response.statuses = None
mock_response.cost_dollars = None
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.contents.AsyncExa",
return_value=MagicMock(
get_contents=AsyncMock(return_value=mock_response)
),
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaContentsBlock.Input(
urls=["https://example.com"],
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert accumulated == []
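Every test in this file captures stats the same way: `patch.object(block, "merge_stats", side_effect=...)` swaps the bound method for a mock whose side effect appends each call's argument to a list. A self-contained sketch of that idiom, using a hypothetical `FakeBlock` rather than a real block class:

```python
from unittest.mock import patch

class FakeBlock:
    """Hypothetical stand-in for a Block; merge_stats normally mutates shared stats."""
    def merge_stats(self, stats):
        raise AssertionError("real merge_stats should not run under the patch")

    def run(self):
        self.merge_stats({"provider_cost": 0.012})

block = FakeBlock()
captured = []
# patch.object replaces the method with a MagicMock whose side_effect records
# each call, so the test can assert on exactly what was merged (and how often).
with patch.object(block, "merge_stats", side_effect=captured.append):
    block.run()

assert captured == [{"provider_cost": 0.012}]
```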
# ---------------------------------------------------------------------------
# SearchOrganizationsBlock — provider_cost = float(len(organizations))
# ---------------------------------------------------------------------------
class TestSearchOrganizationsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_org_count(self):
"""provider_cost == number of returned organizations, type == 'items'."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.models import Organization
from backend.blocks.apollo.organization import SearchOrganizationsBlock
block = SearchOrganizationsBlock()
fake_orgs = [Organization(id=str(i), name=f"Org{i}") for i in range(3)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchOrganizationsBlock,
"search_organizations",
new_callable=AsyncMock,
return_value=fake_orgs,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchOrganizationsBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
            async for _ in block.run(
                input_data,
                credentials=APOLLO_CREDS,
            ):
                pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(3.0)
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_org_list_tracks_zero(self):
"""An empty organization list results in provider_cost=0.0."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.organization import SearchOrganizationsBlock
block = SearchOrganizationsBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchOrganizationsBlock,
"search_organizations",
new_callable=AsyncMock,
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchOrganizationsBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=APOLLO_CREDS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"
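The item-counted convention these tests assert can be sketched on its own: `provider_cost` is simply the number of returned records, typed `"items"`. `items_stats` is a hypothetical helper for illustration, not code from the block:

```python
# Hypothetical helper capturing the per-item convention asserted in these tests:
# provider_cost is the number of returned records and the cost type is "items".
def items_stats(records: list) -> dict:
    return {"provider_cost": float(len(records)), "provider_cost_type": "items"}

assert items_stats(["org-0", "org-1", "org-2"]) == {
    "provider_cost": 3.0,
    "provider_cost_type": "items",
}
assert items_stats([]) == {"provider_cost": 0.0, "provider_cost_type": "items"}
```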
# ---------------------------------------------------------------------------
# JinaEmbeddingBlock — token count from usage.total_tokens
# ---------------------------------------------------------------------------
class TestJinaEmbeddingBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_token_count(self):
"""provider token count is recorded when API returns usage.total_tokens."""
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
block = JinaEmbeddingBlock()
api_response = {
"data": [{"embedding": [0.1, 0.2, 0.3]}],
"usage": {"total_tokens": 42},
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.jina.embeddings.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = JinaEmbeddingBlock.Input(
texts=["hello world"],
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=JINA_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].input_token_count == 42
@pytest.mark.asyncio
async def test_no_merge_stats_when_usage_absent(self):
"""When API response omits usage field, merge_stats is not called."""
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
block = JinaEmbeddingBlock()
api_response = {
"data": [{"embedding": [0.1, 0.2, 0.3]}],
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.jina.embeddings.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = JinaEmbeddingBlock.Input(
texts=["hello"],
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=JINA_CREDS):
pass
assert accumulated == []
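The pair of tests above pivots on whether the response carries a `usage` object. A sketch of that extraction (`extract_total_tokens` is a hypothetical helper, assuming the dict shapes used in the tests):

```python
# Hypothetical extraction mirroring what these tests assert: the response's
# usage.total_tokens becomes the tracked token count; when usage is absent,
# nothing is merged.
def extract_total_tokens(api_response: dict):
    usage = api_response.get("usage")
    if not usage:
        return None
    return usage.get("total_tokens")

assert extract_total_tokens({"data": [], "usage": {"total_tokens": 42}}) == 42
assert extract_total_tokens({"data": []}) is None
```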
# ---------------------------------------------------------------------------
# UnrealTextToSpeechBlock — character count from input text length
# ---------------------------------------------------------------------------
class TestUnrealTextToSpeechBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_character_count(self):
"""provider_cost equals len(text) with type='characters'."""
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
from backend.blocks.text_to_speech_block import (
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
)
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
block = UnrealTextToSpeechBlock()
test_text = "Hello, world!"
with (
patch.object(
UnrealTextToSpeechBlock,
"call_unreal_speech_api",
new_callable=AsyncMock,
return_value={"OutputUri": "https://example.com/audio.mp3"},
),
patch.object(block, "merge_stats") as mock_merge,
):
input_data = UnrealTextToSpeechBlock.Input(
text=test_text,
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=TTS_CREDS):
pass
mock_merge.assert_called_once()
stats = mock_merge.call_args[0][0]
assert stats.provider_cost == float(len(test_text))
assert stats.provider_cost_type == "characters"
@pytest.mark.asyncio
async def test_empty_text_gives_zero_characters(self):
"""An empty text string results in provider_cost=0.0."""
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
from backend.blocks.text_to_speech_block import (
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
)
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
block = UnrealTextToSpeechBlock()
with (
patch.object(
UnrealTextToSpeechBlock,
"call_unreal_speech_api",
new_callable=AsyncMock,
return_value={"OutputUri": "https://example.com/audio.mp3"},
),
patch.object(block, "merge_stats") as mock_merge,
):
input_data = UnrealTextToSpeechBlock.Input(
text="",
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=TTS_CREDS):
pass
mock_merge.assert_called_once()
stats = mock_merge.call_args[0][0]
assert stats.provider_cost == 0.0
assert stats.provider_cost_type == "characters"
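The character-billed convention asserted above reduces to `float(len(text))` with cost type `"characters"`. `characters_stats` is a hypothetical helper for illustration:

```python
# Hypothetical helper for the character-billed convention these tests assert:
# provider_cost is len(text) as a float, typed "characters".
def characters_stats(text: str) -> dict:
    return {"provider_cost": float(len(text)), "provider_cost_type": "characters"}

assert characters_stats("Hello, world!")["provider_cost"] == 13.0
assert characters_stats("")["provider_cost"] == 0.0
```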
# ---------------------------------------------------------------------------
# GoogleMapsSearchBlock — item count from search_places results
# ---------------------------------------------------------------------------
class TestGoogleMapsSearchBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_place_count(self):
"""provider_cost equals number of returned places, type == 'items'."""
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
from backend.blocks.google_maps import (
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
)
from backend.blocks.google_maps import GoogleMapsSearchBlock
block = GoogleMapsSearchBlock()
fake_places = [{"name": f"Place{i}", "address": f"Addr{i}"} for i in range(4)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
GoogleMapsSearchBlock,
"search_places",
return_value=fake_places,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = GoogleMapsSearchBlock.Input(
query="coffee shops",
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=MAPS_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 4.0
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_results_tracks_zero(self):
"""Zero places returned results in provider_cost=0.0."""
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
from backend.blocks.google_maps import (
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
)
from backend.blocks.google_maps import GoogleMapsSearchBlock
block = GoogleMapsSearchBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
GoogleMapsSearchBlock,
"search_places",
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = GoogleMapsSearchBlock.Input(
query="nothing here",
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=MAPS_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# AddLeadToCampaignBlock (SmartLead) — item count from lead_list length
# ---------------------------------------------------------------------------
class TestSmartLeadAddLeadsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_lead_count(self):
"""provider_cost equals number of leads uploaded, type == 'items'."""
from backend.blocks.smartlead._auth import TEST_CREDENTIALS as SL_CREDS
from backend.blocks.smartlead._auth import (
TEST_CREDENTIALS_INPUT as SL_CREDS_INPUT,
)
from backend.blocks.smartlead.campaign import AddLeadToCampaignBlock
from backend.blocks.smartlead.models import (
AddLeadsToCampaignResponse,
LeadInput,
)
block = AddLeadToCampaignBlock()
fake_leads = [
LeadInput(first_name="Alice", last_name="A", email="alice@example.com"),
LeadInput(first_name="Bob", last_name="B", email="bob@example.com"),
]
fake_response = AddLeadsToCampaignResponse(
ok=True,
upload_count=2,
total_leads=2,
block_count=0,
duplicate_count=0,
invalid_email_count=0,
invalid_emails=[],
already_added_to_campaign=0,
unsubscribed_leads=[],
is_lead_limit_exhausted=False,
lead_import_stopped_count=0,
bounce_count=0,
)
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
AddLeadToCampaignBlock,
"add_leads_to_campaign",
new_callable=AsyncMock,
return_value=fake_response,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = AddLeadToCampaignBlock.Input(
campaign_id=123,
lead_list=fake_leads,
credentials=SL_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=SL_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 2.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# SearchPeopleBlock — item count from people list length
# ---------------------------------------------------------------------------
class TestSearchPeopleBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_people_count(self):
"""provider_cost equals number of returned people, type == 'items'."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.models import Contact
from backend.blocks.apollo.people import SearchPeopleBlock
block = SearchPeopleBlock()
fake_people = [Contact(id=str(i), first_name=f"Person{i}") for i in range(5)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchPeopleBlock,
"search_people",
new_callable=AsyncMock,
return_value=fake_people,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchPeopleBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(5.0)
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_people_list_tracks_zero(self):
"""An empty people list results in provider_cost=0.0."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.people import SearchPeopleBlock
block = SearchPeopleBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchPeopleBlock,
"search_people",
new_callable=AsyncMock,
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchPeopleBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"

View File

@@ -9,7 +9,6 @@ from typing import Union
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -117,10 +116,3 @@ class ExaCodeContextBlock(Block):
yield "cost_dollars", context.cost_dollars
yield "search_time", context.search_time
yield "output_tokens", context.output_tokens
# Parse cost_dollars (API returns as string, e.g. "0.005")
try:
cost_usd = float(context.cost_dollars)
self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
except (ValueError, TypeError):
pass

View File

@@ -4,7 +4,6 @@ from typing import Optional
from exa_py import AsyncExa
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -224,6 +223,3 @@ class ExaContentsBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -1,575 +0,0 @@
"""Tests for cost tracking in Exa blocks.
Covers the cost_dollars → provider_cost → merge_stats path for both
ExaContentsBlock and ExaCodeContextBlock.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.data.model import NodeExecutionStats
class TestExaCodeContextCostTracking:
"""ExaCodeContextBlock parses cost_dollars (string) and calls merge_stats."""
@pytest.mark.asyncio
async def test_valid_cost_string_is_parsed_and_merged(self):
"""A numeric cost string like '0.005' is merged as provider_cost."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-1",
"query": "test query",
"response": "some code",
"resultsCount": 3,
"costDollars": "0.005",
"searchTime": 1.2,
"outputTokens": 100,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
outputs = []
async for key, value in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
outputs.append((key, value))
assert any(k == "cost_dollars" for k, _ in outputs)
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.005)
@pytest.mark.asyncio
async def test_invalid_cost_string_does_not_raise(self):
"""A non-numeric cost_dollars value is swallowed silently."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-2",
"query": "test",
"response": "code",
"resultsCount": 0,
"costDollars": "N/A",
"searchTime": 0.5,
"outputTokens": 0,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
outputs = []
async for key, value in block.run(
block.Input(query="test", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
outputs.append((key, value))
# No merge_stats call because float() raised ValueError
assert len(merged) == 0
@pytest.mark.asyncio
async def test_zero_cost_string_is_merged(self):
"""'0.0' is a valid cost — should still be tracked."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-3",
"query": "free query",
"response": "result",
"resultsCount": 1,
"costDollars": "0.0",
"searchTime": 0.1,
"outputTokens": 10,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
async for _ in block.run(
block.Input(query="free query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.0)
class TestExaContentsCostTracking:
"""ExaContentsBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = CostDollars(total=0.012)
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.012)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.contents import ExaContentsBlock
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
@pytest.mark.asyncio
async def test_zero_cost_dollars_is_merged(self):
"""A total of 0.0 (free tier) should still be merged."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = CostDollars(total=0.0)
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.0)
class TestExaSearchCostTracking:
"""ExaSearchBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.helpers import CostDollars
from backend.blocks.exa.search import ExaSearchBlock
block = ExaSearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.resolved_search_type = None
mock_sdk_response.cost_dollars = CostDollars(total=0.008)
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.008)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.search import ExaSearchBlock
block = ExaSearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.resolved_search_type = None
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
class TestExaSimilarCostTracking:
"""ExaFindSimilarBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.helpers import CostDollars
from backend.blocks.exa.similar import ExaFindSimilarBlock
block = ExaFindSimilarBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.request_id = "req-1"
mock_sdk_response.cost_dollars = CostDollars(total=0.015)
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.015)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.similar import ExaFindSimilarBlock
block = ExaFindSimilarBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.request_id = "req-2"
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
# ---------------------------------------------------------------------------
# ExaCreateResearchBlock — cost_dollars from completed poll response
# ---------------------------------------------------------------------------
COMPLETED_RESEARCH_RESPONSE = {
"researchId": "test-research-id",
"status": "completed",
"model": "exa-research",
"instructions": "test instructions",
"createdAt": 1700000000000,
"finishedAt": 1700000060000,
"costDollars": {
"total": 0.05,
"numSearches": 3,
"numPages": 10,
"reasoningTokens": 500,
},
"output": {"content": "Research findings...", "parsed": None},
}
PENDING_RESEARCH_RESPONSE = {
"researchId": "test-research-id",
"status": "pending",
"model": "exa-research",
"instructions": "test instructions",
"createdAt": 1700000000000,
}
class TestExaCreateResearchBlockCostTracking:
"""ExaCreateResearchBlock merges cost from completed poll response."""
@pytest.mark.asyncio
async def test_cost_merged_when_research_completes(self):
"""merge_stats called with provider_cost=total when poll returns completed."""
from backend.blocks.exa.research import ExaCreateResearchBlock
block = ExaCreateResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
create_resp = MagicMock()
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
poll_resp = MagicMock()
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.post = AsyncMock(return_value=create_resp)
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
instructions="test instructions",
wait_for_completion=True,
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When completed response has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaCreateResearchBlock
block = ExaCreateResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
create_resp = MagicMock()
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
poll_resp = MagicMock()
poll_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.post = AsyncMock(return_value=create_resp)
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
instructions="test instructions",
wait_for_completion=True,
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []
# ---------------------------------------------------------------------------
# ExaGetResearchBlock — cost_dollars from single GET response
# ---------------------------------------------------------------------------
class TestExaGetResearchBlockCostTracking:
"""ExaGetResearchBlock merges cost when the fetched research has cost_dollars."""
@pytest.mark.asyncio
async def test_cost_merged_from_completed_research(self):
"""merge_stats called with provider_cost=total when research has costDollars."""
from backend.blocks.exa.research import ExaGetResearchBlock
block = ExaGetResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
get_resp = MagicMock()
get_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=get_resp)
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When research has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaGetResearchBlock
block = ExaGetResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
get_resp = MagicMock()
get_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=get_resp)
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []
# ---------------------------------------------------------------------------
# ExaWaitForResearchBlock — cost_dollars from polling response
# ---------------------------------------------------------------------------
class TestExaWaitForResearchBlockCostTracking:
"""ExaWaitForResearchBlock merges cost when the polled research has cost_dollars."""
@pytest.mark.asyncio
async def test_cost_merged_when_research_completes(self):
"""merge_stats called with provider_cost=total once polling returns completed."""
from backend.blocks.exa.research import ExaWaitForResearchBlock
block = ExaWaitForResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
poll_resp = MagicMock()
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When completed research has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaWaitForResearchBlock
block = ExaWaitForResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
poll_resp = MagicMock()
poll_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []

View File

@@ -12,7 +12,6 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -233,11 +232,6 @@ class ExaCreateResearchBlock(Block):
if research.cost_dollars:
yield "cost_total", research.cost_dollars.total
self.merge_stats(
NodeExecutionStats(
provider_cost=research.cost_dollars.total
)
)
return
await asyncio.sleep(check_interval)
@@ -352,9 +346,6 @@ class ExaGetResearchBlock(Block):
yield "cost_searches", research.cost_dollars.num_searches
yield "cost_pages", research.cost_dollars.num_pages
yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
self.merge_stats(
NodeExecutionStats(provider_cost=research.cost_dollars.total)
)
yield "error_message", research.error
@@ -441,9 +432,6 @@ class ExaWaitForResearchBlock(Block):
if research.cost_dollars:
yield "cost_total", research.cost_dollars.total
self.merge_stats(
NodeExecutionStats(provider_cost=research.cost_dollars.total)
)
return

View File

@@ -4,7 +4,6 @@ from typing import Optional
from exa_py import AsyncExa
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -207,6 +206,3 @@ class ExaSearchBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -3,7 +3,6 @@ from typing import Optional
from exa_py import AsyncExa
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -168,6 +167,3 @@ class ExaFindSimilarBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -14,7 +14,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -118,11 +117,6 @@ class GoogleMapsSearchBlock(Block):
input_data.radius,
input_data.max_results,
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(places)), provider_cost_type="items"
)
)
for place in places:
yield "place", place

View File

@@ -10,7 +10,7 @@ from backend.blocks.jina._auth import (
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.model import NodeExecutionStats, SchemaField
from backend.data.model import SchemaField
from backend.util.request import Requests
@@ -45,13 +45,5 @@ class JinaEmbeddingBlock(Block):
}
data = {"input": input_data.texts, "model": input_data.model}
response = await Requests().post(url, headers=headers, json=data)
resp_json = response.json()
embeddings = [e["embedding"] for e in resp_json["data"]]
usage = resp_json.get("usage", {})
if usage.get("total_tokens"):
self.merge_stats(
NodeExecutionStats(
input_token_count=usage.get("total_tokens", 0),
)
)
embeddings = [e["embedding"] for e in response.json()["data"]]
yield "embeddings", embeddings

View File

@@ -13,7 +13,6 @@ import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
from pydantic import BaseModel, SecretStr
from backend.blocks._base import (
@@ -738,7 +737,6 @@ class LLMResponse(BaseModel):
prompt_tokens: int
completion_tokens: int
reasoning: Optional[str] = None
provider_cost: float | None = None
def convert_openai_tool_fmt_to_anthropic(
@@ -773,32 +771,6 @@ def convert_openai_tool_fmt_to_anthropic(
return anthropic_tools
def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None:
"""Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response.
OpenRouter returns the per-request USD cost in a response header. The
OpenAI SDK exposes the raw httpx response via an undocumented `_response`
attribute. We use try/except AttributeError so that if the SDK ever drops
or renames that attribute, the warning is visible in logs rather than
silently degrading to no cost tracking.
"""
try:
raw_resp = response._response # type: ignore[attr-defined]
except AttributeError:
logger.warning(
"OpenAI SDK response missing _response attribute"
" — OpenRouter cost tracking unavailable"
)
return None
try:
cost_header = raw_resp.headers.get("x-total-cost")
if not cost_header:
return None
return float(cost_header)
except (ValueError, TypeError, AttributeError):
return None
def extract_openai_reasoning(response) -> str | None:
"""Extract reasoning from OpenAI-compatible response if available."""
"""Note: This will likely not working since the reasoning is not present in another Response API"""
@@ -1131,7 +1103,6 @@ async def llm_call(
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
provider_cost=extract_openrouter_cost(response),
)
elif provider == "llama_api":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -1439,7 +1410,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
error_feedback_message = ""
llm_model = input_data.model
last_attempt_cost: float | None = None
for retry_count in range(input_data.retry):
logger.debug(f"LLM request: {prompt}")
@@ -1457,15 +1427,12 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
max_tokens=input_data.max_tokens,
)
response_text = llm_response.response
# Merge token counts for every attempt (each call costs tokens).
# provider_cost (actual USD) is tracked separately and only merged
# on success to avoid double-counting across retries.
token_stats = NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
self.merge_stats(
NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
)
)
self.merge_stats(token_stats)
last_attempt_cost = llm_response.provider_cost
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
if input_data.expected_format:
@@ -1534,7 +1501,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
provider_cost=last_attempt_cost,
)
)
yield "response", response_obj
@@ -1555,7 +1521,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
provider_cost=last_attempt_cost,
)
)
yield "response", {"response": response_text}

View File

@@ -23,7 +23,7 @@ from backend.blocks.smartlead.models import (
SaveSequencesResponse,
Sequence,
)
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
from backend.data.model import CredentialsField, SchemaField
class CreateCampaignBlock(Block):
@@ -226,12 +226,6 @@ class AddLeadToCampaignBlock(Block):
response = await self.add_leads_to_campaign(
input_data.campaign_id, input_data.lead_list, credentials
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(input_data.lead_list)),
provider_cost_type="items",
)
)
yield "campaign_id", input_data.campaign_id
yield "upload_count", response.upload_count

View File

@@ -199,66 +199,6 @@ class TestLLMStatsTracking:
assert block.execution_stats.llm_call_count == 2 # retry_count + 1 = 1 + 1 = 2
assert block.execution_stats.llm_retry_count == 1
@pytest.mark.asyncio
async def test_retry_cost_uses_last_attempt_only(self):
"""provider_cost is only merged from the final successful attempt.
Intermediate retry costs are intentionally dropped to avoid
double-counting: the cost of failed attempts is captured in
last_attempt_cost only when the loop eventually succeeds.
"""
import backend.blocks.llm as llm
block = llm.AIStructuredResponseGeneratorBlock()
call_count = 0
async def mock_llm_call(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
# First attempt: fails validation, returns cost $0.01
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"wrong": "key"}</json_output>',
tool_calls=None,
prompt_tokens=10,
completion_tokens=5,
reasoning=None,
provider_cost=0.01,
)
# Second attempt: succeeds, returns cost $0.02
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
tool_calls=None,
prompt_tokens=20,
completion_tokens=10,
reasoning=None,
provider_cost=0.02,
)
block.llm_call = mock_llm_call # type: ignore
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
)
with patch("secrets.token_hex", return_value="test123456"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
# Only the final successful attempt's cost is merged
assert block.execution_stats.provider_cost == pytest.approx(0.02)
# Tokens from both attempts accumulate
assert block.execution_stats.input_token_count == 30
assert block.execution_stats.output_token_count == 15
@pytest.mark.asyncio
async def test_ai_text_summarizer_multiple_chunks(self):
"""Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""
@@ -1047,51 +987,3 @@ class TestLlmModelMissing:
assert (
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
)
class TestExtractOpenRouterCost:
"""Tests for extract_openrouter_cost — the x-total-cost header parser."""
def _mk_response(self, headers: dict | None):
response = MagicMock()
if headers is None:
response._response = None
else:
raw = MagicMock()
raw.headers = headers
response._response = raw
return response
def test_extracts_numeric_cost(self):
response = self._mk_response({"x-total-cost": "0.0042"})
assert llm.extract_openrouter_cost(response) == 0.0042
def test_returns_none_when_header_missing(self):
response = self._mk_response({})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_header_empty_string(self):
response = self._mk_response({"x-total-cost": ""})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_header_non_numeric(self):
response = self._mk_response({"x-total-cost": "not-a-number"})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_no_response_attr(self):
response = MagicMock(spec=[]) # no _response attr
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_raw_is_none(self):
response = self._mk_response(None)
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_raw_has_no_headers(self):
response = MagicMock()
response._response = MagicMock(spec=[]) # no headers attr
assert llm.extract_openrouter_cost(response) is None
def test_returns_zero_for_zero_cost(self):
"""Zero-cost is a valid value (free tier) and must not become None."""
response = self._mk_response({"x-total-cost": "0"})
assert llm.extract_openrouter_cost(response) == 0.0

View File

@@ -13,7 +13,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -105,10 +104,4 @@ class UnrealTextToSpeechBlock(Block):
input_data.text,
input_data.voice_id,
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(input_data.text)),
provider_cost_type="characters",
)
)
yield "mp3_url", api_response["OutputUri"]

View File

@@ -9,7 +9,6 @@ shared tool registry as the SDK path.
import asyncio
import base64
import logging
import math
import os
import re
import shutil
@@ -23,7 +22,6 @@ from typing import TYPE_CHECKING, Any, cast
import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from opentelemetry import trace as otel_trace
from backend.copilot.config import CopilotMode
from backend.copilot.context import get_workspace_manager, set_execution_context
@@ -336,7 +334,6 @@ class _BaselineStreamState:
text_started: bool = False
turn_prompt_tokens: int = 0
turn_completion_tokens: int = 0
cost_usd: float | None = None
thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper)
session_messages: list[ChatMessage] = field(default_factory=list)
@@ -357,7 +354,6 @@ async def _baseline_llm_caller(
state.thinking_stripper = _ThinkingStripper()
round_text = ""
response = None # initialized before try so finally block can access it
try:
client = _get_openai_client()
typed_messages = cast(list[ChatCompletionMessageParam], messages)
@@ -434,18 +430,6 @@ async def _baseline_llm_caller(
state.text_started = False
state.text_block_id = str(uuid.uuid4())
finally:
# Extract OpenRouter cost from response headers (in finally so we
# capture cost even when the stream errors mid-way — we already paid).
# Accumulate across multi-round tool-calling turns.
try:
cost_header = response._response.headers.get("x-total-cost") # type: ignore[attr-defined]
if cost_header:
cost = float(cost_header)
if math.isfinite(cost):
state.cost_usd = (state.cost_usd or 0.0) + max(0.0, cost)
except (AttributeError, ValueError):
pass
# Always persist partial text so the session history stays consistent,
# even when the stream is interrupted by an exception.
state.assistant_text += round_text
@@ -1199,22 +1183,8 @@ async def stream_chat_completion_baseline(
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
# Set cost attributes on OTEL span before closing
# Close Langfuse trace context
if _trace_ctx is not None:
try:
span = otel_trace.get_current_span()
if span and span.is_recording():
span.set_attribute(
"gen_ai.usage.prompt_tokens", state.turn_prompt_tokens
)
span.set_attribute(
"gen_ai.usage.completion_tokens",
state.turn_completion_tokens,
)
if state.cost_usd is not None:
span.set_attribute("gen_ai.usage.cost_usd", state.cost_usd)
except Exception:
logger.debug("[Baseline] Failed to set OTEL cost attributes")
try:
_trace_ctx.__exit__(None, None, None)
except Exception:
@@ -1256,8 +1226,6 @@ async def stream_chat_completion_baseline(
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
log_prefix="[Baseline]",
cost_usd=state.cost_usd,
model=active_model,
)
# Persist structured tool-call history (assistant + tool messages)

View File

@@ -4,7 +4,7 @@ These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState`
without requiring API keys, database connections, or network access.
"""
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import pytest
from openai.types.chat import ChatCompletionToolParam
@@ -631,169 +631,3 @@ class TestPrepareBaselineAttachments:
assert hint == ""
assert blocks == []
class TestBaselineCostExtraction:
"""Tests for x-total-cost header extraction in _baseline_llm_caller."""
@pytest.mark.asyncio
async def test_cost_usd_extracted_from_response_header(self):
"""state.cost_usd is set from x-total-cost header when present."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
# Build a mock raw httpx response with the cost header
mock_raw_response = MagicMock()
mock_raw_response.headers = {"x-total-cost": "0.0123"}
# Build a mock async streaming response that yields no chunks but has
# a _response attribute pointing to the mock httpx response
mock_stream_response = MagicMock()
mock_stream_response._response = mock_raw_response
async def empty_aiter():
return
yield # make it an async generator
mock_stream_response.__aiter__ = lambda self: empty_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
return_value=mock_stream_response
)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.cost_usd == pytest.approx(0.0123)
@pytest.mark.asyncio
async def test_cost_usd_accumulates_across_calls(self):
"""cost_usd accumulates when _baseline_llm_caller is called multiple times."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
def make_stream_mock(cost: str) -> MagicMock:
mock_raw = MagicMock()
mock_raw.headers = {"x-total-cost": cost}
mock_stream = MagicMock()
mock_stream._response = mock_raw
async def empty_aiter():
return
yield
mock_stream.__aiter__ = lambda self: empty_aiter()
return mock_stream
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
side_effect=[make_stream_mock("0.01"), make_stream_mock("0.02")]
)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "first"}],
tools=[],
state=state,
)
await _baseline_llm_caller(
messages=[{"role": "user", "content": "second"}],
tools=[],
state=state,
)
assert state.cost_usd == pytest.approx(0.03)
@pytest.mark.asyncio
async def test_no_cost_when_header_absent(self):
"""state.cost_usd remains None when response has no x-total-cost header."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
mock_raw = MagicMock()
mock_raw.headers = {}
mock_stream = MagicMock()
mock_stream._response = mock_raw
async def empty_aiter():
return
yield
mock_stream.__aiter__ = lambda self: empty_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.cost_usd is None
@pytest.mark.asyncio
async def test_cost_extracted_even_when_stream_raises(self):
"""cost_usd is captured in the finally block even when streaming fails."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
mock_raw = MagicMock()
mock_raw.headers = {"x-total-cost": "0.005"}
mock_stream = MagicMock()
mock_stream._response = mock_raw
async def failing_aiter():
raise RuntimeError("stream error")
yield # make it an async generator
mock_stream.__aiter__ = lambda self: failing_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
with (
patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
),
pytest.raises(RuntimeError, match="stream error"),
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.cost_usd == pytest.approx(0.005)

View File

@@ -151,8 +151,8 @@ class CoPilotProcessor:
This method is called once per worker thread to set up the async event
loop and initialize any required resources.
DB operations route through DatabaseManagerAsyncClient (RPC) via the
db_accessors pattern — no direct Prisma connection is needed here.
Database is accessed only through DatabaseManager, so we don't need to connect
to Prisma directly.
"""
configure_logging()
set_service_name("CoPilotExecutor")

View File

@@ -15,7 +15,6 @@ from prisma.models import User as PrismaUser
from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from backend.data.db_accessors import user_db
from backend.data.redis_client import get_redis_async
from backend.util.cache import cached
@@ -410,12 +409,9 @@ async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
cached and then persists after the user is created with a higher tier.
"""
try:
user = await user_db().get_user_by_id(user_id)
except Exception:
raise _UserNotFoundError(user_id)
if user.subscription_tier:
return SubscriptionTier(user.subscription_tier)
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
if user and user.subscriptionTier: # type: ignore[reportAttributeAccessIssue]
return SubscriptionTier(user.subscriptionTier) # type: ignore[reportAttributeAccessIssue]
raise _UserNotFoundError(user_id)

View File

@@ -29,7 +29,6 @@ from claude_agent_sdk import (
)
from langfuse import propagate_attributes
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
from opentelemetry import trace as otel_trace
from pydantic import BaseModel
from backend.copilot.context import get_workspace_manager
@@ -2373,26 +2372,8 @@ async def stream_chat_completion_sdk(
raise
finally:
# --- Close OTEL context (with cost attributes) ---
# --- Close OTEL context ---
if _otel_ctx is not None:
try:
span = otel_trace.get_current_span()
if span and span.is_recording():
span.set_attribute("gen_ai.usage.prompt_tokens", turn_prompt_tokens)
span.set_attribute(
"gen_ai.usage.completion_tokens", turn_completion_tokens
)
span.set_attribute(
"gen_ai.usage.cache_read_tokens", turn_cache_read_tokens
)
span.set_attribute(
"gen_ai.usage.cache_creation_tokens",
turn_cache_creation_tokens,
)
if turn_cost_usd is not None:
span.set_attribute("gen_ai.usage.cost_usd", turn_cost_usd)
except Exception:
logger.debug("Failed to set OTEL cost attributes", exc_info=True)
try:
_otel_ctx.__exit__(*sys.exc_info())
except Exception:
@@ -2410,8 +2391,6 @@ async def stream_chat_completion_sdk(
cache_creation_tokens=turn_cache_creation_tokens,
log_prefix=log_prefix,
cost_usd=turn_cost_usd,
model=config.model,
provider="anthropic",
)
# --- Persist session messages ---

View File

@@ -4,85 +4,17 @@ Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
1. Append a ``Usage`` record to the session.
2. Log the turn's token counts.
3. Record weighted usage in Redis for rate-limiting.
4. Write a PlatformCostLog entry for admin cost tracking.
This module extracts that common logic so both paths stay in sync.
"""
import asyncio
import logging
import math
import re
import threading
from backend.data.db_accessors import platform_cost_db
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
from .model import ChatSession, Usage
from .rate_limit import record_token_usage
logger = logging.getLogger(__name__)
# Hold strong references to in-flight cost log tasks to prevent GC.
_pending_log_tasks: set[asyncio.Task[None]] = set()
# Guards all reads and writes to _pending_log_tasks. Done callbacks (discard)
# fire from the event loop thread; drain_pending_cost_logs iterates the set
# from any caller — the lock prevents RuntimeError from concurrent modification.
_pending_log_tasks_lock = threading.Lock()
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
# shared across event loops running in different threads.
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}
def _get_log_semaphore() -> asyncio.Semaphore:
loop = asyncio.get_running_loop()
sem = _log_semaphores.get(loop)
if sem is None:
sem = asyncio.Semaphore(50)
_log_semaphores[loop] = sem
return sem
def _schedule_cost_log(entry: PlatformCostEntry) -> None:
"""Schedule a fire-and-forget cost log via DatabaseManagerAsyncClient RPC."""
async def _safe_log() -> None:
async with _get_log_semaphore():
try:
await platform_cost_db().log_platform_cost(entry)
except Exception:
logger.exception(
"Failed to log platform cost for user=%s provider=%s block=%s",
entry.user_id,
entry.provider,
entry.block_name,
)
task = asyncio.create_task(_safe_log())
with _pending_log_tasks_lock:
_pending_log_tasks.add(task)
def _remove(t: asyncio.Task[None]) -> None:
with _pending_log_tasks_lock:
_pending_log_tasks.discard(t)
task.add_done_callback(_remove)
# Identifiers used by PlatformCostLog for copilot turns (not tied to a real
# block/credential in the block_cost_config or credentials_store tables).
COPILOT_BLOCK_ID = "copilot"
COPILOT_CREDENTIAL_ID = "copilot_system"
def _copilot_block_name(log_prefix: str) -> str:
"""Extract stable block_name from ``"[SDK][session][T1]"`` -> ``"copilot:SDK"``."""
match = re.search(r"\[([A-Za-z][A-Za-z0-9_]*)\]", log_prefix)
if match:
return f"{COPILOT_BLOCK_ID}:{match.group(1)}"
tag = log_prefix.strip(" []")
return f"{COPILOT_BLOCK_ID}:{tag}" if tag else COPILOT_BLOCK_ID
async def persist_and_record_usage(
*,
@@ -94,8 +26,6 @@ async def persist_and_record_usage(
cache_creation_tokens: int = 0,
log_prefix: str = "",
cost_usd: float | str | None = None,
model: str | None = None,
provider: str = "open_router",
) -> int:
"""Persist token usage to session and record for rate limiting.
@@ -108,7 +38,6 @@ async def persist_and_record_usage(
cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
cost_usd: Optional cost for logging (float from SDK, str otherwise).
provider: Cost provider name (e.g. "anthropic", "open_router").
Returns:
The computed total_tokens (prompt + completion; cache excluded).
@@ -118,13 +47,12 @@ async def persist_and_record_usage(
cache_read_tokens = max(0, cache_read_tokens)
cache_creation_tokens = max(0, cache_creation_tokens)
no_tokens = (
if (
prompt_tokens <= 0
and completion_tokens <= 0
and cache_read_tokens <= 0
and cache_creation_tokens <= 0
)
if no_tokens and cost_usd is None:
):
return 0
# total_tokens = prompt + completion. Cache tokens are tracked
@@ -145,14 +73,14 @@ async def persist_and_record_usage(
if cache_read_tokens or cache_creation_tokens:
logger.info(
f"{log_prefix} Turn usage: uncached={prompt_tokens}, cache_read={cache_read_tokens},"
f" cache_create={cache_creation_tokens}, output={completion_tokens},"
f" total={total_tokens}, cost_usd={cost_usd}"
f"{log_prefix} Turn usage: uncached={prompt_tokens}, "
f"cache_read={cache_read_tokens}, cache_create={cache_creation_tokens}, "
f"output={completion_tokens}, total={total_tokens}, cost_usd={cost_usd}"
)
else:
logger.info(
f"{log_prefix} Turn usage: prompt={prompt_tokens}, completion={completion_tokens},"
f" total={total_tokens}"
f"{log_prefix} Turn usage: prompt={prompt_tokens}, "
f"completion={completion_tokens}, total={total_tokens}"
)
if user_id:
@@ -165,54 +93,6 @@ async def persist_and_record_usage(
cache_creation_tokens=cache_creation_tokens,
)
except Exception as usage_err:
logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)
# Log to PlatformCostLog for admin cost dashboard.
# Include entries where cost_usd is set even if token count is 0
# (e.g. fully-cached Anthropic responses where only cache tokens
# accumulate a charge without incrementing total_tokens).
if user_id and (total_tokens > 0 or cost_usd is not None):
cost_float = None
if cost_usd is not None:
try:
val = float(cost_usd)
if math.isfinite(val) and val >= 0:
cost_float = val
except (ValueError, TypeError):
pass
cost_microdollars = usd_to_microdollars(cost_float)
session_id = session.session_id if session else None
if cost_float is not None:
tracking_type = "cost_usd"
tracking_amount = cost_float
else:
tracking_type = "tokens"
tracking_amount = total_tokens
_schedule_cost_log(
PlatformCostEntry(
user_id=user_id,
graph_exec_id=session_id,
block_id=COPILOT_BLOCK_ID,
block_name=_copilot_block_name(log_prefix),
provider=provider,
credential_id=COPILOT_CREDENTIAL_ID,
cost_microdollars=cost_microdollars,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
model=model,
tracking_type=tracking_type,
tracking_amount=tracking_amount,
metadata={
"tracking_type": tracking_type,
"tracking_amount": tracking_amount,
"cache_read_tokens": cache_read_tokens,
"cache_creation_tokens": cache_creation_tokens,
"source": "copilot",
},
)
)
logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")
return total_tokens

View File

@@ -4,7 +4,6 @@ Covers both the baseline (prompt+completion only) and SDK (with cache breakdown)
calling conventions, session persistence, and rate-limit recording.
"""
import asyncio
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
@@ -280,260 +279,3 @@ class TestRateLimitRecording:
completion_tokens=0,
)
mock_record.assert_not_awaited()
# ---------------------------------------------------------------------------
# PlatformCostLog integration
# ---------------------------------------------------------------------------
class TestPlatformCostLogging:
@pytest.mark.asyncio
async def test_logs_cost_entry_with_cost_usd(self):
"""When cost_usd is provided, tracking_type should be 'cost_usd'."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=_make_session(),
user_id="user-cost",
prompt_tokens=200,
completion_tokens=100,
cost_usd=0.005,
model="gpt-4",
provider="anthropic",
log_prefix="[SDK]",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.user_id == "user-cost"
assert entry.provider == "anthropic"
assert entry.model == "gpt-4"
assert entry.cost_microdollars == 5000
assert entry.input_tokens == 200
assert entry.output_tokens == 100
assert entry.tracking_type == "cost_usd"
assert entry.metadata["tracking_type"] == "cost_usd"
assert entry.metadata["tracking_amount"] == 0.005
assert entry.block_name == "copilot:SDK"
assert entry.graph_exec_id == "sess-test"
@pytest.mark.asyncio
async def test_logs_cost_entry_without_cost_usd(self):
"""When cost_usd is None, tracking_type should be 'tokens'."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-tokens",
prompt_tokens=100,
completion_tokens=50,
log_prefix="[Baseline]",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.cost_microdollars is None
assert entry.tracking_type == "tokens"
assert entry.metadata["tracking_type"] == "tokens"
assert entry.metadata["tracking_amount"] == 150
assert entry.graph_exec_id is None
assert entry.block_name == "copilot:Baseline"
@pytest.mark.asyncio
async def test_skips_cost_log_when_no_user_id(self):
"""No PlatformCostLog entry when user_id is None."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=100,
completion_tokens=50,
)
await asyncio.sleep(0)
mock_log.assert_not_awaited()
@pytest.mark.asyncio
async def test_cost_usd_invalid_string_falls_back_to_tokens(self):
"""Invalid cost_usd string should fall back to tokens tracking."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-invalid",
prompt_tokens=100,
completion_tokens=50,
cost_usd="not-a-number",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.cost_microdollars is None
assert entry.metadata["tracking_type"] == "tokens"
@pytest.mark.asyncio
async def test_cost_usd_string_number_is_parsed(self):
"""String-encoded cost_usd (e.g. from OpenRouter) should be parsed."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-str",
prompt_tokens=100,
completion_tokens=50,
cost_usd="0.01",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.cost_microdollars == 10_000
assert entry.metadata["tracking_type"] == "cost_usd"
@pytest.mark.asyncio
async def test_empty_log_prefix_produces_copilot_block_name(self):
"""Empty log_prefix results in block_name='copilot'."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-empty",
prompt_tokens=10,
completion_tokens=5,
log_prefix="",
)
await asyncio.sleep(0)
entry = mock_log.call_args[0][0]
assert entry.block_name == "copilot"
@pytest.mark.asyncio
async def test_cache_tokens_included_in_metadata(self):
"""Cache token counts should be present in the metadata."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-cache",
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=5000,
cache_creation_tokens=300,
)
await asyncio.sleep(0)
entry = mock_log.call_args[0][0]
assert entry.metadata["cache_read_tokens"] == 5000
assert entry.metadata["cache_creation_tokens"] == 300
assert entry.metadata["source"] == "copilot"
@pytest.mark.asyncio
async def test_logs_cost_only_when_tokens_zero(self):
"""Zero prompt+completion tokens with cost_usd set still logs the entry."""
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
),
patch(
"backend.copilot.token_tracking.platform_cost_db",
return_value=type(
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
)(),
),
):
await persist_and_record_usage(
session=None,
user_id="user-cached",
prompt_tokens=0,
completion_tokens=0,
cost_usd=0.005,
model="claude-3-5-sonnet",
provider="anthropic",
log_prefix="[SDK]",
)
await asyncio.sleep(0)
# Guard: total_tokens == 0 but cost_usd is set — must still log
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.user_id == "user-cached"
assert entry.tracking_type == "cost_usd"
assert entry.cost_microdollars == 5000
assert entry.input_tokens == 0
assert entry.output_tokens == 0


@@ -142,9 +142,3 @@ def credit_db():
credit_db = get_database_manager_async_client()
return credit_db
def platform_cost_db():
from backend.util.clients import get_database_manager_async_client
return get_database_manager_async_client()


@@ -96,7 +96,6 @@ from backend.data.notifications import (
remove_notifications_from_batch,
)
from backend.data.onboarding import increment_onboarding_runs
from backend.data.platform_cost import log_platform_cost
from backend.data.understanding import (
get_business_understanding,
upsert_business_understanding,
@@ -333,9 +332,6 @@ class DatabaseManager(AppService):
get_blocks_needing_optimization = _(get_blocks_needing_optimization)
update_block_optimized_description = _(update_block_optimized_description)
# ============ Platform Cost Tracking ============ #
log_platform_cost = _(log_platform_cost)
# ============ CoPilot Chat Sessions ============ #
get_chat_session = _(chat_db.get_chat_session)
create_chat_session = _(chat_db.create_chat_session)
@@ -533,9 +529,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# ============ Block Descriptions ============ #
get_blocks_needing_optimization = d.get_blocks_needing_optimization
# ============ Platform Cost Tracking ============ #
log_platform_cost = d.log_platform_cost
# ============ CoPilot Chat Sessions ============ #
get_chat_session = d.get_chat_session
create_chat_session = d.create_chat_session


@@ -104,11 +104,6 @@ class User(BaseModel):
description="User timezone (IANA timezone identifier or 'not-set')",
)
# Subscription / rate-limit tier
subscription_tier: str | None = Field(
default=None, description="Subscription tier (FREE, PRO, BUSINESS, ENTERPRISE)"
)
@classmethod
def from_db(cls, prisma_user: "PrismaUser") -> "User":
"""Convert a database User object to application User model."""
@@ -163,7 +158,6 @@ class User(BaseModel):
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
timezone=prisma_user.timezone or USER_TIMEZONE_NOT_SET,
subscription_tier=prisma_user.subscriptionTier,
)
@@ -825,17 +819,6 @@ class RefundRequest(BaseModel):
updated_at: datetime
ProviderCostType = Literal[
"cost_usd", # Actual USD cost reported by the provider
"tokens", # LLM token counts (sum of input + output)
"characters", # Per-character billing (TTS providers)
"sandbox_seconds", # Per-second compute billing (e.g. E2B)
"walltime_seconds", # Per-second billing incl. queue/polling
"per_run", # Per-API-call billing with fixed cost
"items", # Per-item billing (lead/organization/result count)
]
class NodeExecutionStats(BaseModel):
"""Execution statistics for a node execution."""
@@ -855,39 +838,32 @@ class NodeExecutionStats(BaseModel):
output_token_count: int = 0
extra_cost: int = 0
extra_steps: int = 0
provider_cost: float | None = None
# Type of the provider-reported cost/usage captured above. When set
# by a block, resolve_tracking honors this directly instead of
# guessing from provider name.
provider_cost_type: Optional[ProviderCostType] = None
# Moderation fields
cleared_inputs: Optional[dict[str, list[str]]] = None
cleared_outputs: Optional[dict[str, list[str]]] = None
def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
"""Mutate this instance by adding another NodeExecutionStats.
Avoids calling model_dump() twice per merge (called on every
merge_stats() from ~20+ blocks); reads via getattr/vars instead.
"""
"""Mutate this instance by adding another NodeExecutionStats."""
if not isinstance(other, NodeExecutionStats):
return NotImplemented
for key in type(other).model_fields:
value = getattr(other, key)
if value is None:
# Never overwrite an existing value with None
continue
current = getattr(self, key, None)
if current is None:
# Field doesn't exist yet or is None, just set it
stats_dict = other.model_dump()
current_stats = self.model_dump()
for key, value in stats_dict.items():
if key not in current_stats:
# Field doesn't exist yet, just set it
setattr(self, key, value)
elif isinstance(value, dict) and isinstance(current, dict):
current.update(value)
elif isinstance(value, (int, float)) and isinstance(current, (int, float)):
setattr(self, key, current + value)
elif isinstance(value, list) and isinstance(current, list):
current.extend(value)
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
current_stats[key].update(value)
setattr(self, key, current_stats[key])
elif isinstance(value, (int, float)) and isinstance(
current_stats[key], (int, float)
):
setattr(self, key, current_stats[key] + value)
elif isinstance(value, list) and isinstance(current_stats[key], list):
current_stats[key].extend(value)
setattr(self, key, current_stats[key])
else:
setattr(self, key, value)


@@ -1,7 +1,7 @@
import pytest
from pydantic import SecretStr
from backend.data.model import HostScopedCredentials, NodeExecutionStats
from backend.data.model import HostScopedCredentials
class TestHostScopedCredentials:
@@ -166,84 +166,3 @@ class TestHostScopedCredentials:
)
assert creds.matches_url(test_url) == expected
class TestNodeExecutionStatsIadd:
def test_adds_numeric_fields(self):
a = NodeExecutionStats(input_token_count=100, output_token_count=50)
b = NodeExecutionStats(input_token_count=200, output_token_count=30)
a += b
assert a.input_token_count == 300
assert a.output_token_count == 80
def test_none_does_not_overwrite(self):
a = NodeExecutionStats(provider_cost=0.5, error="some error")
b = NodeExecutionStats(provider_cost=None, error=None)
a += b
assert a.provider_cost == 0.5
assert a.error == "some error"
def test_none_is_skipped_preserving_existing_value(self):
a = NodeExecutionStats(input_token_count=100)
b = NodeExecutionStats()
a += b
assert a.input_token_count == 100
def test_dict_fields_are_merged(self):
a = NodeExecutionStats(
cleared_inputs={"field1": ["val1"]},
)
b = NodeExecutionStats(
cleared_inputs={"field2": ["val2"]},
)
a += b
assert a.cleared_inputs == {"field1": ["val1"], "field2": ["val2"]}
def test_returns_self(self):
a = NodeExecutionStats()
b = NodeExecutionStats(input_token_count=10)
result = a.__iadd__(b)
assert result is a
def test_not_implemented_for_non_stats(self):
a = NodeExecutionStats()
result = a.__iadd__("not a stats") # type: ignore[arg-type]
assert result is NotImplemented
def test_error_none_does_not_clear_existing_error(self):
a = NodeExecutionStats(error="existing error")
b = NodeExecutionStats(error=None)
a += b
assert a.error == "existing error"
def test_provider_cost_none_does_not_clear_existing_cost(self):
a = NodeExecutionStats(provider_cost=0.05)
b = NodeExecutionStats(provider_cost=None)
a += b
assert a.provider_cost == 0.05
def test_provider_cost_accumulates_when_both_set(self):
a = NodeExecutionStats(provider_cost=0.01)
b = NodeExecutionStats(provider_cost=0.02)
a += b
assert abs((a.provider_cost or 0) - 0.03) < 1e-9
def test_provider_cost_first_write_from_none(self):
a = NodeExecutionStats()
b = NodeExecutionStats(provider_cost=0.05)
a += b
assert a.provider_cost == 0.05
def test_provider_cost_type_first_write_from_none(self):
"""Writing provider_cost_type into a stats with None sets it."""
a = NodeExecutionStats()
b = NodeExecutionStats(provider_cost_type="characters")
a += b
assert a.provider_cost_type == "characters"
def test_provider_cost_type_none_does_not_overwrite(self):
"""A None provider_cost_type from other must not clear an existing value."""
a = NodeExecutionStats(provider_cost_type="tokens")
b = NodeExecutionStats()
a += b
assert a.provider_cost_type == "tokens"


@@ -1,390 +0,0 @@
import asyncio
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Any
from pydantic import BaseModel
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
logger = logging.getLogger(__name__)
MICRODOLLARS_PER_USD = 1_000_000
# Dashboard query limits — keep in sync with the SQL queries below
MAX_PROVIDER_ROWS = 500
MAX_USER_ROWS = 100
# Default date range for dashboard queries when no start date is provided.
# Prevents full-table scans on large deployments.
DEFAULT_DASHBOARD_DAYS = 30
def usd_to_microdollars(cost_usd: float | None) -> int | None:
"""Convert a USD amount (float) to microdollars (int). None-safe."""
if cost_usd is None:
return None
return round(cost_usd * MICRODOLLARS_PER_USD)
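A quick sanity check of the conversion (verbatim copy of the helper above, so nothing here is assumed):

```python
MICRODOLLARS_PER_USD = 1_000_000

def usd_to_microdollars(cost_usd):
    # Copy of the helper above: USD float to integer microdollars, None-safe.
    if cost_usd is None:
        return None
    return round(cost_usd * MICRODOLLARS_PER_USD)
```

So a $0.005 copilot turn is stored as 5000 microdollars, matching the assertions in the token-tracking tests.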
class PlatformCostEntry(BaseModel):
user_id: str
graph_exec_id: str | None = None
node_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
block_id: str
block_name: str
provider: str
credential_id: str
cost_microdollars: int | None = None
input_tokens: int | None = None
output_tokens: int | None = None
data_size: int | None = None
duration: float | None = None
model: str | None = None
tracking_type: str | None = None
tracking_amount: float | None = None
metadata: dict[str, Any] | None = None
async def log_platform_cost(entry: PlatformCostEntry) -> None:
await execute_raw_with_schema(
"""
INSERT INTO {schema_prefix}"PlatformCostLog"
("id", "createdAt", "userId", "graphExecId", "nodeExecId",
"graphId", "nodeId", "blockId", "blockName", "provider",
"credentialId", "costMicrodollars", "inputTokens", "outputTokens",
"dataSize", "duration", "model", "trackingType", "trackingAmount",
"metadata")
VALUES (
gen_random_uuid(), NOW(), $1, $2, $3, $4, $5, $6, $7, $8, $9,
$10, $11, $12, $13, $14, $15, $16, $17, $18::jsonb
)
""",
entry.user_id,
entry.graph_exec_id,
entry.node_exec_id,
entry.graph_id,
entry.node_id,
entry.block_id,
entry.block_name,
# Normalize to lowercase so the (provider, createdAt) index is always
# used without LOWER() on the read side.
entry.provider.lower(),
entry.credential_id,
entry.cost_microdollars,
entry.input_tokens,
entry.output_tokens,
entry.data_size,
entry.duration,
entry.model,
entry.tracking_type,
entry.tracking_amount,
_json_or_none(entry.metadata),
)
# Bound the number of concurrent cost-log DB inserts to prevent unbounded
# task/connection growth under sustained load or DB slowness.
_log_semaphore = asyncio.Semaphore(50)
async def log_platform_cost_safe(entry: PlatformCostEntry) -> None:
"""Fire-and-forget wrapper that never raises."""
try:
async with _log_semaphore:
await log_platform_cost(entry)
except Exception:
logger.exception(
"Failed to log platform cost for user=%s provider=%s block=%s",
entry.user_id,
entry.provider,
entry.block_name,
)
def _json_or_none(data: dict[str, Any] | None) -> str | None:
if data is None:
return None
return json.dumps(data)
def _mask_email(email: str | None) -> str | None:
"""Mask an email address to reduce PII exposure in admin API responses.
Turns 'user@example.com' into 'us***@example.com'.
Handles short local parts gracefully (e.g. 'a@b.com' -> 'a***@b.com').
"""
if not email:
return email
at = email.find("@")
if at < 0:
return "***"
local = email[:at]
domain = email[at:]
visible = local[:2] if len(local) >= 2 else local[:1]
return f"{visible}***{domain}"
class ProviderCostSummary(BaseModel):
provider: str
tracking_type: str | None = None
total_cost_microdollars: int
total_input_tokens: int
total_output_tokens: int
total_duration_seconds: float = 0.0
total_tracking_amount: float = 0.0
request_count: int
class UserCostSummary(BaseModel):
user_id: str | None = None
email: str | None = None
total_cost_microdollars: int
total_input_tokens: int
total_output_tokens: int
request_count: int
class CostLogRow(BaseModel):
id: str
created_at: datetime
user_id: str | None = None
email: str | None = None
graph_exec_id: str | None = None
node_exec_id: str | None = None
block_name: str
provider: str
tracking_type: str | None = None
cost_microdollars: int | None = None
input_tokens: int | None = None
output_tokens: int | None = None
duration: float | None = None
model: str | None = None
class PlatformCostDashboard(BaseModel):
by_provider: list[ProviderCostSummary]
by_user: list[UserCostSummary]
total_cost_microdollars: int
total_requests: int
total_users: int
def _build_where(
start: datetime | None,
end: datetime | None,
provider: str | None,
user_id: str | None,
table_alias: str = "",
) -> tuple[str, list[Any]]:
prefix = f"{table_alias}." if table_alias else ""
clauses: list[str] = []
params: list[Any] = []
idx = 1
if start:
clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
params.append(start)
idx += 1
if end:
clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
params.append(end)
idx += 1
if provider:
# Provider names are normalized to lowercase at write time so a plain
# equality check is sufficient and the (provider, createdAt) index is used.
clauses.append(f'{prefix}"provider" = ${idx}')
params.append(provider.lower())
idx += 1
if user_id:
clauses.append(f'{prefix}"userId" = ${idx}')
params.append(user_id)
idx += 1
return (" AND ".join(clauses) if clauses else "TRUE", params)
async def get_platform_cost_dashboard(
start: datetime | None = None,
end: datetime | None = None,
provider: str | None = None,
user_id: str | None = None,
) -> PlatformCostDashboard:
"""Aggregate platform cost logs for the admin dashboard.
Note: by_provider rows are keyed on (provider, tracking_type). A single
provider can therefore appear in multiple rows if it has entries with
different billing models (e.g. "openai" with both "tokens" and "cost_usd"
if pricing is later added for some entries). The frontend treats each row
independently rather than as a provider primary key.
Defaults to the last DEFAULT_DASHBOARD_DAYS days when no start date is
provided to avoid full-table scans on large deployments.
"""
if start is None:
start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_p, params_p = _build_where(start, end, provider, user_id, "p")
by_provider_rows, by_user_rows, total_user_rows = await asyncio.gather(
query_raw_with_schema(
f"""
SELECT
p."provider",
p."trackingType" AS tracking_type,
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
COALESCE(SUM(p."duration"), 0)::float AS total_duration,
COALESCE(SUM(p."trackingAmount"), 0)::float AS total_tracking_amount,
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
GROUP BY p."provider", p."trackingType"
ORDER BY total_cost DESC
LIMIT {MAX_PROVIDER_ROWS}
""",
*params_p,
),
query_raw_with_schema(
f"""
SELECT
p."userId" AS user_id,
u."email",
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
WHERE {where_p}
GROUP BY p."userId", u."email"
ORDER BY total_cost DESC
LIMIT {MAX_USER_ROWS}
""",
*params_p,
),
query_raw_with_schema(
f"""
SELECT COUNT(DISTINCT p."userId")::bigint AS cnt
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
""",
*params_p,
),
)
# Use the exact COUNT(DISTINCT userId) so total_users is not capped at
# MAX_USER_ROWS (which would silently report 100 for >100 active users).
total_users = int(total_user_rows[0]["cnt"]) if total_user_rows else 0
total_cost = sum(r["total_cost"] for r in by_provider_rows)
total_requests = sum(r["request_count"] for r in by_provider_rows)
return PlatformCostDashboard(
by_provider=[
ProviderCostSummary(
provider=r["provider"],
tracking_type=r.get("tracking_type"),
total_cost_microdollars=r["total_cost"],
total_input_tokens=r["total_input_tokens"],
total_output_tokens=r["total_output_tokens"],
total_duration_seconds=r.get("total_duration", 0.0),
total_tracking_amount=r.get("total_tracking_amount", 0.0),
request_count=r["request_count"],
)
for r in by_provider_rows
],
by_user=[
UserCostSummary(
user_id=r.get("user_id"),
email=_mask_email(r.get("email")),
total_cost_microdollars=r["total_cost"],
total_input_tokens=r["total_input_tokens"],
total_output_tokens=r["total_output_tokens"],
request_count=r["request_count"],
)
for r in by_user_rows
],
total_cost_microdollars=total_cost,
total_requests=total_requests,
total_users=total_users,
)
async def get_platform_cost_logs(
start: datetime | None = None,
end: datetime | None = None,
provider: str | None = None,
user_id: str | None = None,
page: int = 1,
page_size: int = 50,
) -> tuple[list[CostLogRow], int]:
if start is None:
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_sql, params = _build_where(start, end, provider, user_id, "p")
offset = (page - 1) * page_size
limit_idx = len(params) + 1
offset_idx = len(params) + 2
count_rows, rows = await asyncio.gather(
query_raw_with_schema(
f"""
SELECT COUNT(*)::bigint AS cnt
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_sql}
""",
*params,
),
query_raw_with_schema(
f"""
SELECT
p."id",
p."createdAt" AS created_at,
p."userId" AS user_id,
u."email",
p."graphExecId" AS graph_exec_id,
p."nodeExecId" AS node_exec_id,
p."blockName" AS block_name,
p."provider",
p."trackingType" AS tracking_type,
p."costMicrodollars" AS cost_microdollars,
p."inputTokens" AS input_tokens,
p."outputTokens" AS output_tokens,
p."duration",
p."model"
FROM {{schema_prefix}}"PlatformCostLog" p
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
WHERE {where_sql}
ORDER BY p."createdAt" DESC, p."id" DESC
LIMIT ${limit_idx} OFFSET ${offset_idx}
""",
*params,
page_size,
offset,
),
)
total = count_rows[0]["cnt"] if count_rows else 0
logs = [
CostLogRow(
id=r["id"],
created_at=r["created_at"],
user_id=r.get("user_id"),
email=_mask_email(r.get("email")),
graph_exec_id=r.get("graph_exec_id"),
node_exec_id=r.get("node_exec_id"),
block_name=r["block_name"],
provider=r["provider"],
tracking_type=r.get("tracking_type"),
cost_microdollars=r.get("cost_microdollars"),
input_tokens=r.get("input_tokens"),
output_tokens=r.get("output_tokens"),
duration=r.get("duration"),
model=r.get("model"),
)
for r in rows
]
return logs, total
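The pagination arithmetic above reduces to a two-liner (helper name is illustrative; the real code inlines this into the `LIMIT`/`OFFSET` placeholders):

```python
def page_window(page, page_size):
    # Mirrors the LIMIT/OFFSET computation in get_platform_cost_logs:
    # page numbers are 1-based, so page 1 starts at offset 0.
    limit = page_size
    offset = (page - 1) * page_size
    return limit, offset
```

For example, page 3 at 25 rows per page yields offset 50, which is exactly what the pagination test below asserts.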


@@ -1,266 +0,0 @@
"""Unit tests for helpers and async functions in platform_cost module."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from .platform_cost import (
PlatformCostEntry,
_build_where,
_json_or_none,
get_platform_cost_dashboard,
get_platform_cost_logs,
log_platform_cost,
log_platform_cost_safe,
)
class TestJsonOrNone:
def test_returns_none_for_none(self):
assert _json_or_none(None) is None
def test_returns_json_string_for_dict(self):
result = _json_or_none({"key": "value", "num": 42})
assert result is not None
assert '"key"' in result
assert '"value"' in result
def test_returns_json_for_empty_dict(self):
assert _json_or_none({}) == "{}"
class TestBuildWhere:
def test_no_filters_returns_true(self):
sql, params = _build_where(None, None, None, None)
assert sql == "TRUE"
assert params == []
def test_start_only(self):
dt = datetime(2026, 1, 1, tzinfo=timezone.utc)
sql, params = _build_where(dt, None, None, None)
assert '"createdAt" >= $1::timestamptz' in sql
assert params == [dt]
def test_end_only(self):
dt = datetime(2026, 6, 1, tzinfo=timezone.utc)
sql, params = _build_where(None, dt, None, None)
assert '"createdAt" <= $1::timestamptz' in sql
assert params == [dt]
def test_provider_only(self):
# Provider names are normalized to lowercase at write time, so the
# filter uses a plain equality check. The input is also lowercased so
# "OpenAI" and "openai" both match stored rows.
sql, params = _build_where(None, None, "OpenAI", None)
assert '"provider" = $1' in sql
assert params == ["openai"]
def test_user_id_only(self):
sql, params = _build_where(None, None, None, "user-123")
assert '"userId" = $1' in sql
assert params == ["user-123"]
def test_all_filters(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
sql, params = _build_where(start, end, "Anthropic", "u1")
assert "$1" in sql
assert "$2" in sql
assert "$3" in sql
assert "$4" in sql
assert len(params) == 4
# Provider is lowercased at filter time to match stored lowercase values.
assert params == [start, end, "anthropic", "u1"]
def test_table_alias(self):
dt = datetime(2026, 1, 1, tzinfo=timezone.utc)
sql, params = _build_where(dt, None, None, None, table_alias="p")
assert 'p."createdAt"' in sql
assert params == [dt]
def test_clauses_joined_with_and(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
sql, _ = _build_where(start, end, None, None)
assert " AND " in sql
def _make_entry(**overrides: object) -> PlatformCostEntry:
return PlatformCostEntry.model_validate(
{
"user_id": "user-1",
"block_id": "block-1",
"block_name": "TestBlock",
"provider": "openai",
"credential_id": "cred-1",
**overrides,
}
)
class TestLogPlatformCost:
@pytest.mark.asyncio
async def test_calls_execute_raw_with_schema(self):
mock_exec = AsyncMock()
with patch("backend.data.platform_cost.execute_raw_with_schema", new=mock_exec):
entry = _make_entry(
input_tokens=100,
output_tokens=50,
cost_microdollars=5000,
model="gpt-4",
metadata={"key": "val"},
)
await log_platform_cost(entry)
mock_exec.assert_awaited_once()
args = mock_exec.call_args
assert args[0][1] == "user-1" # user_id is first param
assert args[0][6] == "block-1" # block_id
assert args[0][7] == "TestBlock" # block_name
@pytest.mark.asyncio
async def test_metadata_none_passes_none(self):
mock_exec = AsyncMock()
with patch("backend.data.platform_cost.execute_raw_with_schema", new=mock_exec):
entry = _make_entry(metadata=None)
await log_platform_cost(entry)
args = mock_exec.call_args
assert args[0][-1] is None # last arg is metadata json
class TestLogPlatformCostSafe:
@pytest.mark.asyncio
async def test_does_not_raise_on_error(self):
with patch(
"backend.data.platform_cost.execute_raw_with_schema",
new=AsyncMock(side_effect=RuntimeError("DB down")),
):
entry = _make_entry()
await log_platform_cost_safe(entry)
@pytest.mark.asyncio
async def test_succeeds_when_no_error(self):
mock_exec = AsyncMock()
with patch("backend.data.platform_cost.execute_raw_with_schema", new=mock_exec):
entry = _make_entry()
await log_platform_cost_safe(entry)
mock_exec.assert_awaited_once()
class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_returns_dashboard_with_data(self):
provider_rows = [
{
"provider": "openai",
"tracking_type": "tokens",
"total_cost": 5000,
"total_input_tokens": 1000,
"total_output_tokens": 500,
"total_duration": 10.5,
"request_count": 3,
}
]
user_rows = [
{
"user_id": "u1",
"email": "a@b.com",
"total_cost": 5000,
"total_input_tokens": 1000,
"total_output_tokens": 500,
"request_count": 3,
}
]
# Dashboard runs 3 queries: by_provider, by_user, COUNT(DISTINCT userId).
mock_query = AsyncMock(side_effect=[provider_rows, user_rows, [{"cnt": 1}]])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 5000
assert dashboard.total_requests == 3
assert dashboard.total_users == 1
assert len(dashboard.by_provider) == 1
assert dashboard.by_provider[0].provider == "openai"
assert dashboard.by_provider[0].tracking_type == "tokens"
assert dashboard.by_provider[0].total_duration_seconds == 10.5
assert len(dashboard.by_user) == 1
assert dashboard.by_user[0].email == "a***@b.com"
@pytest.mark.asyncio
async def test_returns_empty_dashboard(self):
mock_query = AsyncMock(side_effect=[[], [], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 0
assert dashboard.total_requests == 0
assert dashboard.total_users == 0
assert dashboard.by_provider == []
assert dashboard.by_user == []
@pytest.mark.asyncio
async def test_passes_filters_to_queries(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
mock_query = AsyncMock(side_effect=[[], [], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
await get_platform_cost_dashboard(
start=start, provider="openai", user_id="u1"
)
assert mock_query.await_count == 3
first_call_sql = mock_query.call_args_list[0][0][0]
assert "createdAt" in first_call_sql
class TestGetPlatformCostLogs:
@pytest.mark.asyncio
async def test_returns_logs_and_total(self):
count_rows = [{"cnt": 1}]
log_rows = [
{
"id": "log-1",
"created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
"user_id": "u1",
"email": "a@b.com",
"graph_exec_id": "g1",
"node_exec_id": "n1",
"block_name": "TestBlock",
"provider": "openai",
"tracking_type": "tokens",
"cost_microdollars": 5000,
"input_tokens": 100,
"output_tokens": 50,
"duration": 1.5,
"model": "gpt-4",
}
]
mock_query = AsyncMock(side_effect=[count_rows, log_rows])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
logs, total = await get_platform_cost_logs(page=1, page_size=10)
assert total == 1
assert len(logs) == 1
assert logs[0].id == "log-1"
assert logs[0].provider == "openai"
assert logs[0].model == "gpt-4"
@pytest.mark.asyncio
async def test_returns_empty_when_no_data(self):
mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
logs, total = await get_platform_cost_logs()
assert total == 0
assert logs == []
@pytest.mark.asyncio
async def test_pagination_offset(self):
mock_query = AsyncMock(side_effect=[[{"cnt": 100}], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
logs, total = await get_platform_cost_logs(page=3, page_size=25)
assert total == 100
second_call_args = mock_query.call_args_list[1][0]
assert 25 in second_call_args # page_size
assert 50 in second_call_args # offset = (3-1) * 25
@pytest.mark.asyncio
async def test_empty_count_returns_zero(self):
mock_query = AsyncMock(side_effect=[[], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
logs, total = await get_platform_cost_logs()
assert total == 0
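The offset arithmetic exercised by `test_pagination_offset` above is the usual 1-indexed page formula; a minimal sketch (the helper name is illustrative, not from the source):

```python
def page_to_offset(page: int, page_size: int) -> int:
    """Translate a 1-indexed page number into a SQL OFFSET."""
    return (page - 1) * page_size

# page 3 with page_size 25 skips the first two pages of 25 rows each
print(page_to_offset(3, 25))  # 50
```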

@@ -1,291 +0,0 @@
"""Helpers for platform cost tracking on system-credential block executions."""
import asyncio
import logging
import threading
from typing import TYPE_CHECKING, Any, cast
from backend.blocks._base import Block, BlockSchema
from backend.copilot.token_tracking import _pending_log_tasks as _copilot_tasks
from backend.copilot.token_tracking import (
_pending_log_tasks_lock as _copilot_tasks_lock,
)
from backend.data.execution import NodeExecutionEntry
from backend.data.model import NodeExecutionStats
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
from backend.executor.utils import block_usage_cost
from backend.integrations.credentials_store import is_system_credential
from backend.integrations.providers import ProviderName
if TYPE_CHECKING:
from backend.data.db_manager import DatabaseManagerAsyncClient
logger = logging.getLogger(__name__)
# Provider groupings by billing model — used when the block didn't explicitly
# declare stats.provider_cost_type and we fall back to provider-name
# heuristics. Values match ProviderName enum values.
_CHARACTER_BILLED_PROVIDERS = frozenset(
{ProviderName.D_ID.value, ProviderName.ELEVENLABS.value}
)
_WALLTIME_BILLED_PROVIDERS = frozenset(
{
ProviderName.FAL.value,
ProviderName.REVID.value,
ProviderName.REPLICATE.value,
}
)
# Hold strong references to in-flight log tasks so the event loop doesn't
# garbage-collect them mid-execution. Tasks remove themselves on completion.
# _pending_log_tasks_lock guards all reads and writes: worker threads call
# discard() via done callbacks while drain_pending_cost_logs() iterates.
_pending_log_tasks: set[asyncio.Task] = set()
_pending_log_tasks_lock = threading.Lock()
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
# shared across event loops running in different threads. Key by loop instance
# so each executor worker thread gets its own semaphore.
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}
def _get_log_semaphore() -> asyncio.Semaphore:
loop = asyncio.get_running_loop()
sem = _log_semaphores.get(loop)
if sem is None:
sem = asyncio.Semaphore(50)
_log_semaphores[loop] = sem
return sem
async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
"""Await all in-flight cost log tasks with a timeout.
Drains both the executor cost log tasks (_pending_log_tasks in this module,
used for block execution cost tracking via DatabaseManagerAsyncClient) and
the copilot cost log tasks (token_tracking._pending_log_tasks, used for
copilot LLM turns via platform_cost_db()).
Call this during graceful shutdown to flush pending INSERT tasks before
the process exits. Tasks that don't complete within `timeout` seconds are
abandoned and their failures are already logged by _safe_log.
"""
# asyncio.wait() requires all tasks to belong to the running event loop.
# _pending_log_tasks is shared across executor worker threads (each with
# its own loop), so filter to only tasks owned by the current loop.
# Acquire the lock to take a consistent snapshot (worker threads call
# discard() via done callbacks concurrently with this iteration).
current_loop = asyncio.get_running_loop()
with _pending_log_tasks_lock:
all_pending = [t for t in _pending_log_tasks if t.get_loop() is current_loop]
if all_pending:
logger.info("Draining %d executor cost log task(s)", len(all_pending))
_, still_pending = await asyncio.wait(all_pending, timeout=timeout)
if still_pending:
logger.warning(
"%d executor cost log task(s) did not complete within %.1fs",
len(still_pending),
timeout,
)
# Also drain copilot cost log tasks (token_tracking._pending_log_tasks)
with _copilot_tasks_lock:
copilot_pending = [t for t in _copilot_tasks if t.get_loop() is current_loop]
if copilot_pending:
logger.info("Draining %d copilot cost log task(s)", len(copilot_pending))
_, still_pending = await asyncio.wait(copilot_pending, timeout=timeout)
if still_pending:
logger.warning(
"%d copilot cost log task(s) did not complete within %.1fs",
len(still_pending),
timeout,
)
def _schedule_log(
db_client: "DatabaseManagerAsyncClient", entry: PlatformCostEntry
) -> None:
async def _safe_log() -> None:
async with _get_log_semaphore():
try:
await db_client.log_platform_cost(entry)
except Exception:
logger.exception(
"Failed to log platform cost for user=%s provider=%s block=%s",
entry.user_id,
entry.provider,
entry.block_name,
)
task = asyncio.create_task(_safe_log())
with _pending_log_tasks_lock:
_pending_log_tasks.add(task)
def _remove(t: asyncio.Task) -> None:
with _pending_log_tasks_lock:
_pending_log_tasks.discard(t)
task.add_done_callback(_remove)
def _extract_model_name(raw: str | dict | None) -> str | None:
"""Return a string model name from a block input field, or None.
Handles str (returned as-is), dict (a structured selector, skipped), and
None (no model field). Unexpected types (e.g. Enum members) are coerced to str.
"""
if raw is None:
return None
if isinstance(raw, str):
return raw
if isinstance(raw, dict):
return None
return str(raw)
def resolve_tracking(
provider: str,
stats: NodeExecutionStats,
input_data: dict[str, Any],
) -> tuple[str, float]:
"""Return (tracking_type, tracking_amount) based on provider billing model.
Preference order:
1. Block-declared: if the block set `provider_cost_type` on its stats,
honor it directly (paired with `provider_cost` as the amount).
2. Heuristic fallback: infer from `provider_cost`/token counts, then
from provider name for per-character / per-second billing.
"""
# 1. Block explicitly declared its cost type (only when an amount is present)
if stats.provider_cost_type and stats.provider_cost is not None:
return stats.provider_cost_type, stats.provider_cost
# 2. Provider returned actual USD cost (OpenRouter, Exa)
if stats.provider_cost is not None:
return "cost_usd", stats.provider_cost
# 3. LLM providers: track by tokens
if stats.input_token_count or stats.output_token_count:
return "tokens", float(
(stats.input_token_count or 0) + (stats.output_token_count or 0)
)
# 4. Provider-specific billing heuristics
# TTS: billed per character of input text
if provider == ProviderName.UNREAL_SPEECH.value:
text = input_data.get("text", "")
return "characters", float(len(text)) if isinstance(text, str) else 0.0
# D-ID + ElevenLabs voice: billed per character of script
if provider in _CHARACTER_BILLED_PROVIDERS:
text = (
input_data.get("script_input", "")
or input_data.get("text", "")
or input_data.get("script", "") # VideoNarrationBlock uses `script`
)
return "characters", float(len(text)) if isinstance(text, str) else 0.0
# E2B: billed per second of sandbox time
if provider == ProviderName.E2B.value:
return "sandbox_seconds", round(stats.walltime, 3) if stats.walltime else 0.0
# Video/image gen: walltime includes queue + generation + polling
if provider in _WALLTIME_BILLED_PROVIDERS:
return "walltime_seconds", round(stats.walltime, 3) if stats.walltime else 0.0
# Per-request: Google Maps, Ideogram, Nvidia, Apollo, etc.
# All billed per API call - count 1 per block execution.
return "per_run", 1.0
async def log_system_credential_cost(
node_exec: NodeExecutionEntry,
block: Block,
stats: NodeExecutionStats,
db_client: "DatabaseManagerAsyncClient",
) -> None:
"""Check if a system credential was used and log the platform cost.
Routes through DatabaseManagerAsyncClient so the write goes via the
message-passing DB service rather than calling Prisma directly (which
is not connected in the executor process).
Logs only the first matching system credential field (one log per
execution). Any unexpected error is caught and logged — cost logging
is strictly best-effort and must never disrupt block execution.
Note: costMicrodollars is left null for providers that don't return
a USD cost. The credit_cost in metadata captures our internal credit
charge as a proxy.
"""
try:
if node_exec.execution_context.dry_run:
return
input_data = node_exec.inputs
input_model = cast(type[BlockSchema], block.input_schema)
for field_name in input_model.get_credentials_fields():
cred_data = input_data.get(field_name)
if not cred_data or not isinstance(cred_data, dict):
continue
cred_id = cred_data.get("id", "")
if not cred_id or not is_system_credential(cred_id):
continue
model_name = _extract_model_name(input_data.get("model"))
credit_cost, _ = block_usage_cost(block=block, input_data=input_data)
provider_name = cred_data.get("provider", "unknown")
tracking_type, tracking_amount = resolve_tracking(
provider=provider_name,
stats=stats,
input_data=input_data,
)
# Only treat provider_cost as USD when the tracking type says so.
# For other types (items, characters, per_run, ...) the
# provider_cost field holds the raw amount, not a dollar value.
# Use tracking_amount (the normalized value from resolve_tracking)
# rather than raw stats.provider_cost to avoid unit mismatches.
cost_microdollars = None
if tracking_type == "cost_usd":
cost_microdollars = usd_to_microdollars(tracking_amount)
meta: dict[str, Any] = {
"tracking_type": tracking_type,
"tracking_amount": tracking_amount,
}
if credit_cost is not None:
meta["credit_cost"] = credit_cost
if stats.provider_cost is not None:
# Use 'provider_cost_raw' — the value's unit varies by tracking
# type (USD for cost_usd, count for items/characters/per_run, etc.)
meta["provider_cost_raw"] = stats.provider_cost
_schedule_log(
db_client,
PlatformCostEntry(
user_id=node_exec.user_id,
graph_exec_id=node_exec.graph_exec_id,
node_exec_id=node_exec.node_exec_id,
graph_id=node_exec.graph_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block_name=block.name,
provider=provider_name,
credential_id=cred_id,
cost_microdollars=cost_microdollars,
input_tokens=stats.input_token_count,
output_tokens=stats.output_token_count,
data_size=stats.output_size if stats.output_size > 0 else None,
duration=stats.walltime if stats.walltime > 0 else None,
model=model_name,
tracking_type=tracking_type,
tracking_amount=tracking_amount,
metadata=meta,
),
)
return # One log per execution is enough
except Exception:
logger.exception("log_system_credential_cost failed unexpectedly")
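The task-retention pattern used by `_schedule_log` and `drain_pending_cost_logs` above can be demonstrated in isolation. asyncio keeps only weak references to tasks, so a fire-and-forget task can be garbage-collected before it finishes unless something holds a strong reference; this sketch mirrors the module (simplified: single thread, no lock):

```python
import asyncio

# Strong references keep in-flight tasks alive; tasks remove themselves on completion.
_pending: set[asyncio.Task] = set()

def schedule(coro) -> asyncio.Task:
    task = asyncio.create_task(coro)
    _pending.add(task)                        # strong ref until done
    task.add_done_callback(_pending.discard)  # self-removal on completion
    return task

async def main() -> list[int]:
    results: list[int] = []

    async def work(n: int) -> None:
        await asyncio.sleep(0)
        results.append(n)

    for n in range(3):
        schedule(work(n))
    # Drain only tasks owned by this loop, as drain_pending_cost_logs does.
    loop = asyncio.get_running_loop()
    mine = [t for t in _pending if t.get_loop() is loop]
    if mine:
        await asyncio.wait(mine, timeout=5.0)
    return sorted(results)

print(asyncio.run(main()))  # [0, 1, 2]
```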

@@ -45,10 +45,6 @@ from backend.data.notifications import (
ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.cost_tracking import (
drain_pending_cost_logs,
log_system_credential_cost,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.util import json
@@ -696,15 +692,6 @@ class ExecutionProcessor:
stats=graph_stats,
)
# Log platform cost if system credentials were used (only on success)
if status == ExecutionStatus.COMPLETED:
await log_system_credential_cost(
node_exec=node_exec,
block=node.block,
stats=execution_stats,
db_client=db_client,
)
return execution_stats
@async_time_measured
@@ -2057,23 +2044,14 @@ class ExecutionManager(AppProcess):
prefix + " [cancel-consumer]",
)
# Drain any in-flight cost log tasks before exit so we don't silently
# drop INSERT operations during deployments.
loop = getattr(self, "node_execution_loop", None)
if loop is not None and loop.is_running():
try:
asyncio.run_coroutine_threadsafe(
drain_pending_cost_logs(), loop
).result(timeout=10)
logger.info(f"{prefix} ✅ Cost log tasks drained")
except Exception as e:
logger.warning(f"{prefix} ⚠️ Failed to drain cost log tasks: {e}")
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
super().cleanup()
# ------- UTILITIES ------- #
def get_db_client() -> "DatabaseManagerClient":
return get_database_manager_client()

@@ -1,567 +0,0 @@
"""Unit tests for resolve_tracking and log_system_credential_cost."""
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.execution import ExecutionContext, NodeExecutionEntry
from backend.data.model import NodeExecutionStats
from backend.executor.cost_tracking import log_system_credential_cost, resolve_tracking
# ---------------------------------------------------------------------------
# resolve_tracking
# ---------------------------------------------------------------------------
class TestResolveTracking:
def _stats(self, **overrides: Any) -> NodeExecutionStats:
return NodeExecutionStats(**overrides)
def test_provider_cost_returns_cost_usd(self):
stats = self._stats(provider_cost=0.0042)
tt, amt = resolve_tracking("openai", stats, {})
assert tt == "cost_usd"
assert amt == 0.0042
def test_token_counts_return_tokens(self):
stats = self._stats(input_token_count=300, output_token_count=100)
tt, amt = resolve_tracking("anthropic", stats, {})
assert tt == "tokens"
assert amt == 400.0
def test_token_counts_only_input(self):
stats = self._stats(input_token_count=500)
tt, amt = resolve_tracking("groq", stats, {})
assert tt == "tokens"
assert amt == 500.0
def test_unreal_speech_returns_characters(self):
stats = self._stats()
tt, amt = resolve_tracking("unreal_speech", stats, {"text": "Hello world"})
assert tt == "characters"
assert amt == 11.0
def test_unreal_speech_empty_text(self):
stats = self._stats()
tt, amt = resolve_tracking("unreal_speech", stats, {"text": ""})
assert tt == "characters"
assert amt == 0.0
def test_unreal_speech_non_string_text(self):
stats = self._stats()
tt, amt = resolve_tracking("unreal_speech", stats, {"text": 123})
assert tt == "characters"
assert amt == 0.0
def test_d_id_uses_script_input(self):
stats = self._stats()
tt, amt = resolve_tracking("d_id", stats, {"script_input": "Hello"})
assert tt == "characters"
assert amt == 5.0
def test_elevenlabs_uses_text(self):
stats = self._stats()
tt, amt = resolve_tracking("elevenlabs", stats, {"text": "Say this"})
assert tt == "characters"
assert amt == 8.0
def test_elevenlabs_fallback_to_text_when_no_script_input(self):
stats = self._stats()
tt, amt = resolve_tracking("elevenlabs", stats, {"text": "Fallback text"})
assert tt == "characters"
assert amt == 13.0
def test_elevenlabs_uses_script_field(self):
"""VideoNarrationBlock (elevenlabs) uses `script` field, not script_input/text."""
stats = self._stats()
tt, amt = resolve_tracking("elevenlabs", stats, {"script": "Narration"})
assert tt == "characters"
assert amt == 9.0
def test_block_declared_cost_type_items(self):
"""Block explicitly setting provider_cost_type='items' short-circuits heuristics."""
stats = self._stats(provider_cost=5.0, provider_cost_type="items")
tt, amt = resolve_tracking("google_maps", stats, {})
assert tt == "items"
assert amt == 5.0
def test_block_declared_cost_type_characters(self):
"""TTS block can declare characters directly, bypassing input_data lookup."""
stats = self._stats(provider_cost=42.0, provider_cost_type="characters")
tt, amt = resolve_tracking("unreal_speech", stats, {})
assert tt == "characters"
assert amt == 42.0
def test_block_declared_cost_type_wins_over_tokens(self):
"""provider_cost_type takes precedence over token-based heuristic."""
stats = self._stats(
provider_cost=1.0,
provider_cost_type="per_run",
input_token_count=500,
)
tt, amt = resolve_tracking("openai", stats, {})
assert tt == "per_run"
assert amt == 1.0
def test_e2b_returns_sandbox_seconds(self):
stats = self._stats(walltime=45.123)
tt, amt = resolve_tracking("e2b", stats, {})
assert tt == "sandbox_seconds"
assert amt == 45.123
def test_e2b_no_walltime(self):
stats = self._stats(walltime=0)
tt, amt = resolve_tracking("e2b", stats, {})
assert tt == "sandbox_seconds"
assert amt == 0.0
def test_fal_returns_walltime(self):
stats = self._stats(walltime=12.5)
tt, amt = resolve_tracking("fal", stats, {})
assert tt == "walltime_seconds"
assert amt == 12.5
def test_revid_returns_walltime(self):
stats = self._stats(walltime=60.0)
tt, amt = resolve_tracking("revid", stats, {})
assert tt == "walltime_seconds"
assert amt == 60.0
def test_replicate_returns_walltime(self):
stats = self._stats(walltime=30.0)
tt, amt = resolve_tracking("replicate", stats, {})
assert tt == "walltime_seconds"
assert amt == 30.0
def test_unknown_provider_returns_per_run(self):
stats = self._stats()
tt, amt = resolve_tracking("google_maps", stats, {})
assert tt == "per_run"
assert amt == 1.0
def test_provider_cost_takes_precedence_over_tokens(self):
stats = self._stats(
provider_cost=0.01, input_token_count=500, output_token_count=200
)
tt, amt = resolve_tracking("openai", stats, {})
assert tt == "cost_usd"
assert amt == 0.01
def test_provider_cost_zero_is_not_none(self):
"""provider_cost=0.0 is falsy but should still be tracked as cost_usd
(e.g. free-tier or fully-cached responses from OpenRouter)."""
stats = self._stats(provider_cost=0.0)
tt, amt = resolve_tracking("open_router", stats, {})
assert tt == "cost_usd"
assert amt == 0.0
def test_tokens_take_precedence_over_provider_specific(self):
stats = self._stats(input_token_count=100, walltime=10.0)
tt, amt = resolve_tracking("fal", stats, {})
assert tt == "tokens"
assert amt == 100.0
# ---------------------------------------------------------------------------
# log_system_credential_cost
# ---------------------------------------------------------------------------
def _make_db_client() -> MagicMock:
db_client = MagicMock()
db_client.log_platform_cost = AsyncMock()
return db_client
def _make_block(has_credentials: bool = True) -> MagicMock:
block = MagicMock()
block.name = "TestBlock"
input_schema = MagicMock()
if has_credentials:
input_schema.get_credentials_fields.return_value = {"credentials": MagicMock()}
else:
input_schema.get_credentials_fields.return_value = {}
block.input_schema = input_schema
return block
def _make_node_exec(
inputs: dict | None = None,
dry_run: bool = False,
) -> NodeExecutionEntry:
return NodeExecutionEntry(
user_id="user-1",
graph_exec_id="gx-1",
graph_id="g-1",
graph_version=1,
node_exec_id="nx-1",
node_id="n-1",
block_id="b-1",
inputs=inputs or {},
execution_context=ExecutionContext(dry_run=dry_run),
)
class TestLogSystemCredentialCost:
@pytest.mark.asyncio
async def test_skips_dry_run(self):
db_client = _make_db_client()
node_exec = _make_node_exec(dry_run=True)
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
db_client.log_platform_cost.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_no_credential_fields(self):
db_client = _make_db_client()
node_exec = _make_node_exec(inputs={})
block = _make_block(has_credentials=False)
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
db_client.log_platform_cost.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_cred_data_missing(self):
db_client = _make_db_client()
node_exec = _make_node_exec(inputs={})
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
db_client.log_platform_cost.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_not_system_credential(self):
db_client = _make_db_client()
with patch(
"backend.executor.cost_tracking.is_system_credential",
return_value=False,
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "user-cred-123", "provider": "openai"},
}
)
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
db_client.log_platform_cost.assert_not_awaited()
@pytest.mark.asyncio
async def test_logs_with_system_credential(self):
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(10, None),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred-1", "provider": "openai"},
"model": "gpt-4",
}
)
block = _make_block()
stats = NodeExecutionStats(input_token_count=500, output_token_count=200)
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
db_client.log_platform_cost.assert_awaited_once()
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.user_id == "user-1"
assert entry.provider == "openai"
assert entry.block_name == "TestBlock"
assert entry.model == "gpt-4"
assert entry.input_tokens == 500
assert entry.output_tokens == 200
assert entry.tracking_type == "tokens"
assert entry.metadata["tracking_type"] == "tokens"
assert entry.metadata["tracking_amount"] == 700.0
assert entry.metadata["credit_cost"] == 10
@pytest.mark.asyncio
async def test_logs_with_provider_cost(self):
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(5, None),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred-2", "provider": "open_router"},
}
)
block = _make_block()
stats = NodeExecutionStats(provider_cost=0.0015)
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.cost_microdollars == 1500
assert entry.tracking_type == "cost_usd"
assert entry.metadata["tracking_type"] == "cost_usd"
assert entry.metadata["provider_cost_raw"] == 0.0015
@pytest.mark.asyncio
async def test_model_name_enum_converted_to_str(self):
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(0, None),
),
):
from enum import Enum
class FakeModel(Enum):
GPT4 = "gpt-4"
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred", "provider": "openai"},
"model": FakeModel.GPT4,
}
)
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.model == "FakeModel.GPT4"
@pytest.mark.asyncio
async def test_model_name_dict_becomes_none(self):
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(0, None),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred", "provider": "openai"},
"model": {"nested": "value"},
}
)
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.model is None
@pytest.mark.asyncio
async def test_does_not_raise_when_block_usage_cost_raises(self):
"""log_system_credential_cost must swallow exceptions from block_usage_cost."""
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
side_effect=RuntimeError("pricing lookup failed"),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred", "provider": "openai"},
}
)
block = _make_block()
stats = NodeExecutionStats()
# Should not raise — outer except must catch block_usage_cost error
await log_system_credential_cost(node_exec, block, stats, db_client)
@pytest.mark.asyncio
async def test_round_instead_of_int_for_microdollars(self):
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(0, None),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred", "provider": "openai"},
}
)
block = _make_block()
# Float products of usd * 1_000_000 can land fractionally below the
# exact integer; round() yields 1500 where int() would truncate such
# near-misses to 1499
stats = NodeExecutionStats(provider_cost=0.0015)
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.cost_microdollars == 1500
@pytest.mark.asyncio
async def test_per_run_metadata_has_no_provider_cost_raw(self):
"""For per-run providers (google_maps etc.), provider_cost_raw is absent
from metadata since stats.provider_cost is None."""
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(0, None),
),
):
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred", "provider": "google_maps"},
}
)
block = _make_block()
stats = NodeExecutionStats() # no provider_cost
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.tracking_type == "per_run"
assert "provider_cost_raw" not in (entry.metadata or {})
# ---------------------------------------------------------------------------
# merge_stats accumulation
# ---------------------------------------------------------------------------
class TestMergeStats:
"""Tests for NodeExecutionStats accumulation via += (used by Block.merge_stats)."""
def test_accumulates_output_size(self):
stats = NodeExecutionStats()
stats += NodeExecutionStats(output_size=10)
stats += NodeExecutionStats(output_size=25)
assert stats.output_size == 35
def test_accumulates_tokens(self):
stats = NodeExecutionStats()
stats += NodeExecutionStats(input_token_count=100, output_token_count=50)
stats += NodeExecutionStats(input_token_count=200, output_token_count=150)
assert stats.input_token_count == 300
assert stats.output_token_count == 200
def test_preserves_provider_cost(self):
stats = NodeExecutionStats()
stats += NodeExecutionStats(provider_cost=0.005)
stats += NodeExecutionStats(output_size=10)
assert stats.provider_cost == 0.005
assert stats.output_size == 10
def test_provider_cost_accumulates(self):
"""Multiple merge_stats with provider_cost should sum (multi-round
tool-calling in copilot / retries can report cost separately)."""
stats = NodeExecutionStats()
stats += NodeExecutionStats(provider_cost=0.001)
stats += NodeExecutionStats(provider_cost=0.002)
stats += NodeExecutionStats(provider_cost=0.003)
assert stats.provider_cost == pytest.approx(0.006)
def test_provider_cost_none_does_not_overwrite(self):
"""A None provider_cost must not wipe a previously-set value."""
stats = NodeExecutionStats(provider_cost=0.01)
stats += NodeExecutionStats() # provider_cost=None by default
assert stats.provider_cost == 0.01
def test_provider_cost_type_last_write_wins(self):
"""provider_cost_type is a Literal — last set value wins on merge."""
stats = NodeExecutionStats(provider_cost_type="tokens")
stats += NodeExecutionStats(provider_cost_type="items")
assert stats.provider_cost_type == "items"
# ---------------------------------------------------------------------------
# on_node_execution -> log_system_credential_cost integration
# ---------------------------------------------------------------------------
class TestManagerCostTrackingIntegration:
@pytest.mark.asyncio
async def test_log_called_with_accumulated_stats(self):
"""Verify that log_system_credential_cost receives stats that could
have been accumulated by merge_stats across multiple yield steps."""
db_client = _make_db_client()
with (
patch(
"backend.executor.cost_tracking.is_system_credential", return_value=True
),
patch(
"backend.executor.cost_tracking.block_usage_cost",
return_value=(5, None),
),
):
stats = NodeExecutionStats()
stats += NodeExecutionStats(output_size=10, input_token_count=100)
stats += NodeExecutionStats(output_size=25, input_token_count=200)
assert stats.output_size == 35
assert stats.input_token_count == 300
node_exec = _make_node_exec(
inputs={
"credentials": {"id": "sys-cred-acc", "provider": "openai"},
"model": "gpt-4",
}
)
block = _make_block()
await log_system_credential_cost(node_exec, block, stats, db_client)
await asyncio.sleep(0)
db_client.log_platform_cost.assert_awaited_once()
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.input_tokens == 300
assert entry.tracking_type == "tokens"
assert entry.metadata["tracking_amount"] == 300.0
@pytest.mark.asyncio
async def test_skips_cost_log_when_status_is_failed(self):
"""Manager only calls log_system_credential_cost on COMPLETED status.
This test replicates the manager's guard condition `if status == COMPLETED`
directly: log_system_credential_cost is invoked only on success, never for
FAILED or ERROR executions.
"""
from backend.data.execution import ExecutionStatus
db_client = _make_db_client()
node_exec = _make_node_exec(
inputs={"credentials": {"id": "sys-cred", "provider": "openai"}}
)
block = _make_block()
stats = NodeExecutionStats(input_token_count=100)
# Simulate the manager guard: only call on COMPLETED
status = ExecutionStatus.FAILED
if status == ExecutionStatus.COMPLETED:
await log_system_credential_cost(node_exec, block, stats, db_client)
db_client.log_platform_cost.assert_not_awaited()
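The rounding behavior pinned by `test_round_instead_of_int_for_microdollars` comes down to one design choice in the conversion helper. A sketch of what `usd_to_microdollars` plausibly does (the real helper lives in `backend.data.platform_cost` and is not shown in this diff, so the signature is an assumption):

```python
def usd_to_microdollars(usd: float) -> int:
    # round(), not int(): a float product like usd * 1_000_000 can land a
    # hair below the exact integer, and int() would truncate it to N - 1.
    return round(usd * 1_000_000)

assert usd_to_microdollars(0.0015) == 1500
# 0.1 + 0.7 is 0.7999999999999999 in floats; round() still yields 800_000
assert usd_to_microdollars(0.1 + 0.7) == 800_000
```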

@@ -1,42 +0,0 @@
-- CreateTable
CREATE TABLE "PlatformCostLog" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT,
"graphExecId" TEXT,
"nodeExecId" TEXT,
"graphId" TEXT,
"nodeId" TEXT,
"blockId" TEXT NOT NULL,
"blockName" TEXT NOT NULL,
"provider" TEXT NOT NULL,
"credentialId" TEXT NOT NULL,
"costMicrodollars" BIGINT,
"inputTokens" INTEGER,
"outputTokens" INTEGER,
"dataSize" INTEGER,
"duration" DOUBLE PRECISION,
"model" TEXT,
"trackingType" TEXT,
"metadata" JSONB,
CONSTRAINT "PlatformCostLog_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "PlatformCostLog_userId_createdAt_idx" ON "PlatformCostLog"("userId", "createdAt");
-- CreateIndex
CREATE INDEX "PlatformCostLog_provider_createdAt_idx" ON "PlatformCostLog"("provider", "createdAt");
-- CreateIndex
CREATE INDEX "PlatformCostLog_createdAt_idx" ON "PlatformCostLog"("createdAt");
-- CreateIndex
CREATE INDEX "PlatformCostLog_graphExecId_idx" ON "PlatformCostLog"("graphExecId");
-- CreateIndex
CREATE INDEX "PlatformCostLog_provider_trackingType_idx" ON "PlatformCostLog"("provider", "trackingType");
-- AddForeignKey
ALTER TABLE "PlatformCostLog" ADD CONSTRAINT "PlatformCostLog_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;

@@ -1,2 +0,0 @@
-- AlterTable
ALTER TABLE "PlatformCostLog" ADD COLUMN "trackingAmount" DOUBLE PRECISION;

@@ -75,8 +75,6 @@ model User {
PendingHumanReviews PendingHumanReview[]
Workspace UserWorkspace?
PlatformCostLogs PlatformCostLog[]
// OAuth Provider relations
OAuthApplications OAuthApplication[]
OAuthAuthorizationCodes OAuthAuthorizationCode[]
@@ -817,45 +815,6 @@ model CreditRefundRequest {
@@index([userId, transactionKey])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////// Platform Cost Tracking TABLES //////////////
////////////////////////////////////////////////////////////
model PlatformCostLog {
id String @id @default(uuid())
createdAt DateTime @default(now())
userId String?
User User? @relation(fields: [userId], references: [id], onDelete: SetNull)
graphExecId String?
nodeExecId String?
graphId String?
nodeId String?
blockId String
blockName String
provider String
credentialId String
// Cost in microdollars (1 USD = 1,000,000). Null if unknown.
costMicrodollars BigInt?
inputTokens Int?
outputTokens Int?
dataSize Int? // bytes
duration Float? // seconds
model String?
trackingType String? // e.g. "cost_usd", "tokens", "characters", "items", "per_run", "sandbox_seconds", "walltime_seconds"
trackingAmount Float? // Amount in the unit implied by trackingType
metadata Json?
@@index([userId, createdAt])
@@index([provider, createdAt])
@@index([createdAt])
@@index([graphExecId])
@@index([provider, trackingType])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// Store TABLES ///////////////////////////


@@ -1,12 +1,6 @@
import { Sidebar } from "@/components/__legacy__/Sidebar";
import {
Users,
CurrencyDollar,
MagnifyingGlass,
Gauge,
Receipt,
FileText,
} from "@phosphor-icons/react/dist/ssr";
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
import { Gauge } from "@phosphor-icons/react/dist/ssr";
import { IconSliders } from "@/components/__legacy__/ui/icons";
@@ -21,23 +15,18 @@ const sidebarLinkGroups = [
{
text: "User Spending",
href: "/admin/spending",
icon: <CurrencyDollar className="h-6 w-6" />,
icon: <DollarSign className="h-6 w-6" />,
},
{
text: "User Impersonation",
href: "/admin/impersonation",
icon: <MagnifyingGlass className="h-6 w-6" />,
icon: <UserSearch className="h-6 w-6" />,
},
{
text: "Rate Limits",
href: "/admin/rate-limits",
icon: <Gauge className="h-6 w-6" />,
},
{
text: "Platform Costs",
href: "/admin/platform-costs",
icon: <Receipt className="h-6 w-6" />,
},
{
text: "Execution Analytics",
href: "/admin/execution-analytics",


@@ -1,429 +0,0 @@
import {
render,
screen,
cleanup,
waitFor,
} from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { PlatformCostContent } from "../components/PlatformCostContent";
import type { PlatformCostDashboard } from "@/app/api/__generated__/models/platformCostDashboard";
import type { PlatformCostLogsResponse } from "@/app/api/__generated__/models/platformCostLogsResponse";
// Mock the generated Orval hooks so tests don't hit the network
const mockUseGetDashboard = vi.fn();
const mockUseGetLogs = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({
useGetV2GetPlatformCostDashboard: (...args: unknown[]) =>
mockUseGetDashboard(...args),
useGetV2GetPlatformCostLogs: (...args: unknown[]) => mockUseGetLogs(...args),
}));
afterEach(() => {
cleanup();
mockUseGetDashboard.mockReset();
mockUseGetLogs.mockReset();
});
const emptyDashboard: PlatformCostDashboard = {
total_cost_microdollars: 0,
total_requests: 0,
total_users: 0,
by_provider: [],
by_user: [],
};
const emptyLogs: PlatformCostLogsResponse = {
logs: [],
pagination: {
current_page: 1,
page_size: 50,
total_items: 0,
total_pages: 0,
},
};
const dashboardWithData: PlatformCostDashboard = {
total_cost_microdollars: 5_000_000,
total_requests: 100,
total_users: 5,
by_provider: [
{
provider: "openai",
tracking_type: "tokens",
total_cost_microdollars: 3_000_000,
total_input_tokens: 50000,
total_output_tokens: 20000,
total_duration_seconds: 0,
request_count: 60,
},
{
provider: "google_maps",
tracking_type: "per_run",
total_cost_microdollars: 0,
total_input_tokens: 0,
total_output_tokens: 0,
total_duration_seconds: 0,
request_count: 40,
},
],
by_user: [
{
user_id: "user-1",
email: "alice@example.com",
total_cost_microdollars: 3_000_000,
total_input_tokens: 50000,
total_output_tokens: 20000,
request_count: 60,
},
],
};
const logsWithData: PlatformCostLogsResponse = {
logs: [
{
id: "log-1",
created_at: "2026-03-01T00:00:00Z" as unknown as Date,
user_id: "user-1",
email: "alice@example.com",
graph_exec_id: "gx-123",
node_exec_id: "nx-456",
block_name: "LLMBlock",
provider: "openai",
tracking_type: "tokens",
cost_microdollars: 5000,
input_tokens: 100,
output_tokens: 50,
duration: 1.5,
model: "gpt-4",
},
],
pagination: {
current_page: 1,
page_size: 50,
total_items: 1,
total_pages: 1,
},
};
function renderComponent(searchParams = {}) {
return render(<PlatformCostContent searchParams={searchParams} />);
}
describe("PlatformCostContent", () => {
it("shows loading state initially", () => {
mockUseGetDashboard.mockReturnValue({ data: undefined, isLoading: true });
mockUseGetLogs.mockReturnValue({ data: undefined, isLoading: true });
renderComponent();
// Loading state renders Skeleton placeholders (animate-pulse divs) instead of content
expect(screen.queryByText("Loading...")).toBeNull();
// Summary cards and table content are not yet shown
expect(screen.queryByText("Known Cost")).toBeNull();
});
it("renders empty dashboard", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: emptyLogs,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Verify the two summary cards that show $0.0000 — Known Cost and Estimated Total
const zeroCostItems = screen.getAllByText("$0.0000");
expect(zeroCostItems.length).toBe(2);
expect(screen.getByText("No cost data yet")).toBeDefined();
});
it("renders dashboard with provider data", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("$5.0000")).toBeDefined();
expect(screen.getByText("100")).toBeDefined();
expect(screen.getByText("5")).toBeDefined();
expect(screen.getByText("openai")).toBeDefined();
expect(screen.getByText("google_maps")).toBeDefined();
});
it("renders tracking type badges", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("tokens")).toBeDefined();
expect(screen.getByText("per_run")).toBeDefined();
});
it("shows error state on fetch failure", async () => {
mockUseGetDashboard.mockReturnValue({
data: undefined,
isLoading: false,
error: new Error("Network error"),
});
mockUseGetLogs.mockReturnValue({
data: undefined,
isLoading: false,
error: new Error("Network error"),
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Network error")).toBeDefined();
});
it("renders tab buttons", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("By Provider")).toBeDefined();
expect(screen.getByText("By User")).toBeDefined();
expect(screen.getByText("Raw Logs")).toBeDefined();
});
it("renders summary cards with correct labels", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("Estimated Total")).toBeDefined();
expect(screen.getByText("Total Requests")).toBeDefined();
expect(screen.getByText("Active Users")).toBeDefined();
});
it("renders filter inputs", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Start Date")).toBeDefined();
expect(screen.getByText("End Date")).toBeDefined();
expect(screen.getAllByText(/Provider/i).length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("User ID")).toBeDefined();
expect(screen.getByText("Apply")).toBeDefined();
});
it("renders by-user tab when specified", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "by-user" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("alice@example.com")).toBeDefined();
});
it("renders logs tab when specified", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("LLMBlock")).toBeDefined();
expect(screen.getByText("gpt-4")).toBeDefined();
});
it("renders no logs message when empty", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("No logs found")).toBeDefined();
});
it("shows pagination when multiple pages", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
const multiPageLogs: PlatformCostLogsResponse = {
logs: logsWithData.logs,
pagination: {
current_page: 1,
page_size: 50,
total_items: 200,
total_pages: 4,
},
};
mockUseGetLogs.mockReturnValue({
data: multiPageLogs,
isLoading: false,
});
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Previous")).toBeDefined();
expect(screen.getByText("Next")).toBeDefined();
expect(screen.getByText(/Page 1 of 4/)).toBeDefined();
});
it("renders user table with unknown email", async () => {
const dashWithNullEmail: PlatformCostDashboard = {
...dashboardWithData,
by_user: [
{
user_id: "user-2",
email: null,
total_cost_microdollars: 1000,
total_input_tokens: 100,
total_output_tokens: 50,
request_count: 5,
},
],
};
mockUseGetDashboard.mockReturnValue({
data: dashWithNullEmail,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent({ tab: "by-user" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Unknown")).toBeDefined();
});
it("by-user tab content visible when tab=by-user param set", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "by-user" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("alice@example.com")).toBeDefined();
// overview tab content should not be visible
expect(screen.queryByText("openai")).toBeNull();
});
it("logs tab content visible when tab=logs param set", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("LLMBlock")).toBeDefined();
expect(screen.getByText("gpt-4")).toBeDefined();
});
it("renders log with null user as dash", async () => {
const logWithNullUser: PlatformCostLogsResponse = {
logs: [
{
id: "log-2",
created_at: "2026-03-01T00:00:00Z" as unknown as Date,
user_id: null,
email: null,
graph_exec_id: null,
node_exec_id: null,
block_name: "copilot:SDK",
provider: "anthropic",
tracking_type: "cost_usd",
cost_microdollars: 15000,
input_tokens: null,
output_tokens: null,
duration: null,
model: "claude-opus-4-20250514",
},
],
pagination: {
current_page: 1,
page_size: 50,
total_items: 1,
total_pages: 1,
},
};
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logWithNullUser,
isLoading: false,
});
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("copilot:SDK")).toBeDefined();
expect(screen.getByText("anthropic")).toBeDefined();
// null email + null user_id renders as "-" in the User column; multiple
// other cells (tokens, duration, session) also render "-", so use
// getAllByText to avoid the single-match constraint.
expect(screen.getAllByText("-").length).toBeGreaterThan(0);
});
});


@@ -1,87 +0,0 @@
import { describe, expect, it, vi } from "vitest";
const mockGetDashboard = vi.fn();
const mockGetLogs = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({
getV2GetPlatformCostDashboard: (...args: unknown[]) =>
mockGetDashboard(...args),
getV2GetPlatformCostLogs: (...args: unknown[]) => mockGetLogs(...args),
}));
import { getPlatformCostDashboard, getPlatformCostLogs } from "../actions";
describe("getPlatformCostDashboard", () => {
it("returns data on success", async () => {
const mockData = { total_cost_microdollars: 1000, total_requests: 5 };
mockGetDashboard.mockResolvedValue({ status: 200, data: mockData });
const result = await getPlatformCostDashboard();
expect(result).toEqual(mockData);
});
it("returns undefined on non-200", async () => {
mockGetDashboard.mockResolvedValue({ status: 401 });
const result = await getPlatformCostDashboard();
expect(result).toBeUndefined();
});
it("passes filter params to API", async () => {
mockGetDashboard.mockReset();
mockGetDashboard.mockResolvedValue({ status: 200, data: {} });
await getPlatformCostDashboard({
start: "2026-01-01T00:00:00",
end: "2026-06-01T00:00:00",
provider: "openai",
user_id: "user-1",
});
expect(mockGetDashboard).toHaveBeenCalledTimes(1);
const params = mockGetDashboard.mock.calls[0][0];
expect(params.start).toBe("2026-01-01T00:00:00");
expect(params.end).toBe("2026-06-01T00:00:00");
expect(params.provider).toBe("openai");
expect(params.user_id).toBe("user-1");
});
it("passes undefined for empty filter strings", async () => {
mockGetDashboard.mockReset();
mockGetDashboard.mockResolvedValue({ status: 200, data: {} });
await getPlatformCostDashboard({
start: "",
provider: "",
user_id: "",
});
expect(mockGetDashboard).toHaveBeenCalledTimes(1);
const params = mockGetDashboard.mock.calls[0][0];
expect(params.start).toBeUndefined();
expect(params.provider).toBeUndefined();
expect(params.user_id).toBeUndefined();
});
});
describe("getPlatformCostLogs", () => {
it("returns data on success", async () => {
const mockData = { logs: [], pagination: { current_page: 1 } };
mockGetLogs.mockResolvedValue({ status: 200, data: mockData });
const result = await getPlatformCostLogs();
expect(result).toEqual(mockData);
});
it("passes page and page_size", async () => {
mockGetLogs.mockReset();
mockGetLogs.mockResolvedValue({ status: 200, data: { logs: [] } });
await getPlatformCostLogs({ page: 3, page_size: 25 });
expect(mockGetLogs).toHaveBeenCalledTimes(1);
const params = mockGetLogs.mock.calls[0][0];
expect(params.page).toBe(3);
expect(params.page_size).toBe(25);
});
it("passes start date string through to API", async () => {
mockGetLogs.mockReset();
mockGetLogs.mockResolvedValue({ status: 200, data: { logs: [] } });
await getPlatformCostLogs({ start: "2026-03-01T00:00:00" });
expect(mockGetLogs).toHaveBeenCalledTimes(1);
const params = mockGetLogs.mock.calls[0][0];
expect(params.start).toBe("2026-03-01T00:00:00");
});
});


@@ -1,300 +0,0 @@
import { describe, expect, it } from "vitest";
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
import {
toDateOrUndefined,
formatMicrodollars,
formatTokens,
formatDuration,
estimateCostForRow,
trackingValue,
toLocalInput,
toUtcIso,
} from "../helpers";
function makeRow(overrides: Partial<ProviderCostSummary>): ProviderCostSummary {
return {
provider: "openai",
tracking_type: null,
total_cost_microdollars: 0,
total_input_tokens: 0,
total_output_tokens: 0,
total_duration_seconds: 0,
request_count: 0,
...overrides,
};
}
describe("toDateOrUndefined", () => {
it("returns undefined for empty string", () => {
expect(toDateOrUndefined("")).toBeUndefined();
});
it("returns undefined for undefined", () => {
expect(toDateOrUndefined(undefined)).toBeUndefined();
});
it("returns undefined for invalid date string", () => {
expect(toDateOrUndefined("not-a-date")).toBeUndefined();
});
it("returns a Date for a valid ISO string", () => {
const result = toDateOrUndefined("2026-01-15T00:00:00Z");
expect(result).toBeInstanceOf(Date);
expect(result!.toISOString()).toBe("2026-01-15T00:00:00.000Z");
});
});
describe("formatMicrodollars", () => {
it("formats zero", () => {
expect(formatMicrodollars(0)).toBe("$0.0000");
});
it("formats a small amount", () => {
expect(formatMicrodollars(50_000)).toBe("$0.0500");
});
it("formats one dollar", () => {
expect(formatMicrodollars(1_000_000)).toBe("$1.0000");
});
});
describe("formatTokens", () => {
it("formats small numbers as-is", () => {
expect(formatTokens(500)).toBe("500");
});
it("formats thousands with K suffix", () => {
expect(formatTokens(1_500)).toBe("1.5K");
});
it("formats millions with M suffix", () => {
expect(formatTokens(2_500_000)).toBe("2.5M");
});
});
describe("formatDuration", () => {
it("formats seconds", () => {
expect(formatDuration(30)).toBe("30.0s");
});
it("formats minutes", () => {
expect(formatDuration(90)).toBe("1.5m");
});
it("formats hours", () => {
expect(formatDuration(5400)).toBe("1.5h");
});
});
describe("estimateCostForRow", () => {
it("returns microdollars directly for cost_usd tracking", () => {
const row = makeRow({
tracking_type: "cost_usd",
total_cost_microdollars: 500_000,
});
expect(estimateCostForRow(row, {})).toBe(500_000);
});
it("returns reported cost for token tracking when cost > 0", () => {
const row = makeRow({
tracking_type: "tokens",
total_cost_microdollars: 100_000,
total_input_tokens: 1000,
total_output_tokens: 500,
});
expect(estimateCostForRow(row, {})).toBe(100_000);
});
it("estimates cost from default rate for token tracking with zero cost", () => {
const row = makeRow({
provider: "openai",
tracking_type: "tokens",
total_cost_microdollars: 0,
total_input_tokens: 500,
total_output_tokens: 500,
});
// 1000 tokens / 1000 * 0.005 USD * 1_000_000 = 5000
expect(estimateCostForRow(row, {})).toBe(5000);
});
it("returns null for unknown token provider with zero cost", () => {
const row = makeRow({
provider: "unknown_provider",
tracking_type: "tokens",
total_cost_microdollars: 0,
});
expect(estimateCostForRow(row, {})).toBeNull();
});
it("uses per-run override when provided", () => {
const row = makeRow({
provider: "google_maps",
tracking_type: "per_run",
request_count: 10,
});
// override = 0.05 * 10 * 1_000_000 = 500_000
expect(estimateCostForRow(row, { "google_maps:per_run": 0.05 })).toBe(
500_000,
);
});
it("uses default per-run cost when no override", () => {
const row = makeRow({
provider: "google_maps",
tracking_type: null,
request_count: 5,
});
// 0.032 * 5 * 1_000_000 = 160_000
expect(estimateCostForRow(row, {})).toBe(160_000);
});
it("returns null for unknown per_run provider", () => {
const row = makeRow({
provider: "totally_unknown",
tracking_type: "per_run",
request_count: 3,
});
expect(estimateCostForRow(row, {})).toBeNull();
});
it("returns null for duration tracking with no rate and no cost", () => {
const row = makeRow({
provider: "openai",
tracking_type: "duration_seconds",
total_cost_microdollars: 0,
total_duration_seconds: 100,
});
expect(estimateCostForRow(row, {})).toBeNull();
});
it("estimates cost from default rate for characters tracking", () => {
const row = makeRow({
provider: "elevenlabs",
tracking_type: "characters",
total_cost_microdollars: 0,
total_tracking_amount: 2000,
});
// 2000 chars / 1000 * 0.18 USD * 1_000_000 = 360_000
expect(estimateCostForRow(row, {})).toBe(360_000);
});
it("estimates cost from default rate for items tracking", () => {
const row = makeRow({
provider: "apollo",
tracking_type: "items",
total_cost_microdollars: 0,
total_tracking_amount: 50,
});
// 50 * 0.02 * 1_000_000 = 1_000_000
expect(estimateCostForRow(row, {})).toBe(1_000_000);
});
it("estimates cost from default rate for duration tracking", () => {
const row = makeRow({
provider: "e2b",
tracking_type: "sandbox_seconds",
total_cost_microdollars: 0,
total_duration_seconds: 1_000_000,
});
// 1_000_000 * 0.000014 * 1_000_000 = 14_000_000
expect(estimateCostForRow(row, {})).toBe(14_000_000);
});
});
describe("trackingValue", () => {
it("returns formatted microdollars for cost_usd", () => {
const row = makeRow({
tracking_type: "cost_usd",
total_cost_microdollars: 1_000_000,
});
expect(trackingValue(row)).toBe("$1.0000");
});
it("returns formatted token count for tokens", () => {
const row = makeRow({
tracking_type: "tokens",
total_input_tokens: 500,
total_output_tokens: 500,
});
expect(trackingValue(row)).toBe("1.0K tokens");
});
it("returns formatted duration for sandbox_seconds", () => {
const row = makeRow({
tracking_type: "sandbox_seconds",
total_duration_seconds: 120,
});
expect(trackingValue(row)).toBe("2.0m");
});
it("returns run count for per_run (default tracking)", () => {
const row = makeRow({
tracking_type: null,
request_count: 42,
});
expect(trackingValue(row)).toBe("42 runs");
});
it("returns formatted character count for characters tracking", () => {
const row = makeRow({
tracking_type: "characters",
total_tracking_amount: 2500,
});
expect(trackingValue(row)).toBe("2.5K chars");
});
it("returns formatted item count for items tracking", () => {
const row = makeRow({
tracking_type: "items",
total_tracking_amount: 1234,
});
expect(trackingValue(row)).toBe("1,234 items");
});
it("returns formatted duration in hours for long sandbox_seconds", () => {
const row = makeRow({
tracking_type: "sandbox_seconds",
total_duration_seconds: 7200,
});
expect(trackingValue(row)).toBe("2.0h");
});
it("returns formatted duration for walltime_seconds", () => {
const row = makeRow({
tracking_type: "walltime_seconds",
total_duration_seconds: 45,
});
expect(trackingValue(row)).toBe("45.0s");
});
});
describe("toLocalInput", () => {
it("returns empty string for empty input", () => {
expect(toLocalInput("")).toBe("");
});
it("returns empty string for invalid ISO", () => {
expect(toLocalInput("not-a-date")).toBe("");
});
it("converts UTC ISO to local datetime-local format", () => {
const result = toLocalInput("2026-01-15T12:30:00Z");
// Format should be YYYY-MM-DDTHH:mm
expect(result).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$/);
});
});
describe("toUtcIso", () => {
it("returns empty string for empty input", () => {
expect(toUtcIso("")).toBe("");
});
it("returns empty string for invalid local time", () => {
expect(toUtcIso("not-a-date")).toBe("");
});
it("converts local datetime-local to ISO string", () => {
const result = toUtcIso("2026-01-15T12:30");
expect(result).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/);
});
});


@@ -1,45 +0,0 @@
import {
getV2GetPlatformCostDashboard,
getV2GetPlatformCostLogs,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { okData } from "@/app/api/helpers";
// Backend expects ISO datetime strings. The generated client's URL builder
// calls .toString() on values, which for Date objects produces the
// human-readable "Tue Mar 31 2026 22:00:00 GMT+0000 (Coordinated Universal
// Time)" format that FastAPI rejects with a 422. The params are already UTC
// ISO strings from the URL, so we forward them unchanged through an
// `as unknown as Date` cast to satisfy the generated typing without ever
// triggering Date.toString().
export async function getPlatformCostDashboard(params?: {
start?: string;
end?: string;
provider?: string;
user_id?: string;
}) {
const response = await getV2GetPlatformCostDashboard({
start: (params?.start || undefined) as unknown as Date | undefined,
end: (params?.end || undefined) as unknown as Date | undefined,
provider: params?.provider || undefined,
user_id: params?.user_id || undefined,
});
return okData(response);
}
export async function getPlatformCostLogs(params?: {
start?: string;
end?: string;
provider?: string;
user_id?: string;
page?: number;
page_size?: number;
}) {
const response = await getV2GetPlatformCostLogs({
start: (params?.start || undefined) as unknown as Date | undefined,
end: (params?.end || undefined) as unknown as Date | undefined,
provider: params?.provider || undefined,
user_id: params?.user_id || undefined,
page: params?.page,
page_size: params?.page_size,
});
return okData(response);
}


@@ -1,140 +0,0 @@
import type { CostLogRow } from "@/app/api/__generated__/models/costLogRow";
import type { Pagination } from "@/app/api/__generated__/models/pagination";
import { formatDuration, formatMicrodollars, formatTokens } from "../helpers";
import { TrackingBadge } from "./TrackingBadge";
function formatLogDate(value: unknown): string {
if (value instanceof Date) return value.toLocaleString();
if (typeof value === "string" || typeof value === "number")
return new Date(value).toLocaleString();
return "-";
}
interface Props {
logs: CostLogRow[];
pagination: Pagination | null;
onPageChange: (page: number) => void;
}
function LogsTable({ logs, pagination, onPageChange }: Props) {
return (
<div className="flex flex-col gap-4">
<div className="overflow-x-auto">
<table className="w-full text-left text-sm">
<thead className="border-b text-xs uppercase text-muted-foreground">
<tr>
<th scope="col" className="px-3 py-3">
Time
</th>
<th scope="col" className="px-3 py-3">
User
</th>
<th scope="col" className="px-3 py-3">
Block
</th>
<th scope="col" className="px-3 py-3">
Provider
</th>
<th scope="col" className="px-3 py-3">
Type
</th>
<th scope="col" className="px-3 py-3">
Model
</th>
<th scope="col" className="px-3 py-3 text-right">
Cost
</th>
<th scope="col" className="px-3 py-3 text-right">
Tokens
</th>
<th scope="col" className="px-3 py-3 text-right">
Duration
</th>
<th scope="col" className="px-3 py-3">
Session
</th>
</tr>
</thead>
<tbody>
{logs.map((log) => (
<tr key={log.id} className="border-b hover:bg-muted">
<td className="whitespace-nowrap px-3 py-2 text-xs">
{formatLogDate(log.created_at)}
</td>
<td className="px-3 py-2 text-xs">
{log.email ||
(log.user_id ? String(log.user_id).slice(0, 8) : "-")}
</td>
<td className="px-3 py-2 text-xs font-medium">
{log.block_name}
</td>
<td className="px-3 py-2 text-xs">{log.provider}</td>
<td className="px-3 py-2 text-xs">
<TrackingBadge trackingType={log.tracking_type} />
</td>
<td className="px-3 py-2 text-xs">{log.model || "-"}</td>
<td className="px-3 py-2 text-right text-xs">
{log.cost_microdollars != null
? formatMicrodollars(Number(log.cost_microdollars))
: "-"}
</td>
<td className="px-3 py-2 text-right text-xs">
{log.input_tokens != null || log.output_tokens != null
? `${formatTokens(Number(log.input_tokens ?? 0))} / ${formatTokens(Number(log.output_tokens ?? 0))}`
: "-"}
</td>
<td className="px-3 py-2 text-right text-xs">
{log.duration != null
? formatDuration(Number(log.duration))
: "-"}
</td>
<td className="px-3 py-2 text-xs text-muted-foreground">
{log.graph_exec_id
? String(log.graph_exec_id).slice(0, 8)
: "-"}
</td>
</tr>
))}
{logs.length === 0 && (
<tr>
<td
colSpan={10}
className="px-4 py-8 text-center text-muted-foreground"
>
No logs found
</td>
</tr>
)}
</tbody>
</table>
</div>
{pagination && pagination.total_pages > 1 && (
<div className="flex items-center justify-between px-4">
<span className="text-sm text-muted-foreground">
Page {pagination.current_page} of {pagination.total_pages} (
{pagination.total_items} total)
</span>
<div className="flex gap-2">
<button
disabled={pagination.current_page <= 1}
onClick={() => onPageChange(pagination.current_page - 1)}
className="rounded border px-3 py-1 text-sm disabled:opacity-50"
>
Previous
</button>
<button
disabled={pagination.current_page >= pagination.total_pages}
onClick={() => onPageChange(pagination.current_page + 1)}
className="rounded border px-3 py-1 text-sm disabled:opacity-50"
>
Next
</button>
</div>
</div>
)}
</div>
);
}
export { LogsTable };


@@ -1,233 +0,0 @@
"use client";
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { formatMicrodollars } from "../helpers";
import { SummaryCard } from "./SummaryCard";
import { ProviderTable } from "./ProviderTable";
import { UserTable } from "./UserTable";
import { LogsTable } from "./LogsTable";
import { usePlatformCostContent } from "./usePlatformCostContent";
interface Props {
searchParams: {
start?: string;
end?: string;
provider?: string;
user_id?: string;
page?: string;
tab?: string;
};
}
function PlatformCostContent({ searchParams }: Props) {
const {
dashboard,
logs,
pagination,
loading,
error,
totalEstimatedCost,
tab,
startInput,
setStartInput,
endInput,
setEndInput,
providerInput,
setProviderInput,
userInput,
setUserInput,
rateOverrides,
handleRateOverride,
updateUrl,
handleFilter,
} = usePlatformCostContent(searchParams);
return (
<div className="flex flex-col gap-6">
<div className="flex flex-wrap items-end gap-3 rounded-lg border p-4">
<div className="flex flex-col gap-1">
<label htmlFor="start-date" className="text-sm text-muted-foreground">
Start Date <span className="text-xs">(local time)</span>
</label>
<input
id="start-date"
type="datetime-local"
className="rounded border px-3 py-1.5 text-sm"
value={startInput}
onChange={(e) => setStartInput(e.target.value)}
/>
</div>
<div className="flex flex-col gap-1">
<label htmlFor="end-date" className="text-sm text-muted-foreground">
End Date <span className="text-xs">(local time)</span>
</label>
<input
id="end-date"
type="datetime-local"
className="rounded border px-3 py-1.5 text-sm"
value={endInput}
onChange={(e) => setEndInput(e.target.value)}
/>
</div>
<div className="flex flex-col gap-1">
<label
htmlFor="provider-filter"
className="text-sm text-muted-foreground"
>
Provider
</label>
<input
id="provider-filter"
type="text"
placeholder="e.g. openai"
className="rounded border px-3 py-1.5 text-sm"
value={providerInput}
onChange={(e) => setProviderInput(e.target.value)}
/>
</div>
<div className="flex flex-col gap-1">
<label
htmlFor="user-id-filter"
className="text-sm text-muted-foreground"
>
User ID
</label>
<input
id="user-id-filter"
type="text"
placeholder="Filter by user"
className="rounded border px-3 py-1.5 text-sm"
value={userInput}
onChange={(e) => setUserInput(e.target.value)}
/>
</div>
<button
onClick={handleFilter}
className="rounded bg-primary px-4 py-1.5 text-sm text-primary-foreground hover:bg-primary/90"
>
Apply
</button>
<button
onClick={() => {
setStartInput("");
setEndInput("");
setProviderInput("");
setUserInput("");
updateUrl({
start: "",
end: "",
provider: "",
user_id: "",
page: "1",
});
}}
className="rounded border px-4 py-1.5 text-sm hover:bg-muted"
>
Clear
</button>
</div>
{error && (
<Alert variant="error">
<AlertDescription>{error}</AlertDescription>
</Alert>
)}
{loading ? (
<div className="flex flex-col gap-4">
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
{[...Array(4)].map((_, i) => (
<Skeleton key={i} className="h-20 rounded-lg" />
))}
</div>
<Skeleton className="h-8 w-48 rounded" />
<Skeleton className="h-64 rounded-lg" />
</div>
) : (
<>
{dashboard && (
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
<SummaryCard
label="Known Cost"
value={formatMicrodollars(dashboard.total_cost_microdollars)}
subtitle="From providers that report USD cost"
/>
<SummaryCard
label="Estimated Total"
value={formatMicrodollars(totalEstimatedCost)}
subtitle="Including per-run cost estimates"
/>
<SummaryCard
label="Total Requests"
value={dashboard.total_requests.toLocaleString()}
/>
<SummaryCard
label="Active Users"
value={dashboard.total_users.toLocaleString()}
/>
</div>
)}
<div
role="tablist"
aria-label="Cost view tabs"
className="flex gap-2 border-b"
>
{["overview", "by-user", "logs"].map((t) => (
<button
key={t}
id={`tab-${t}`}
role="tab"
aria-selected={tab === t}
aria-controls={`tabpanel-${t}`}
onClick={() => updateUrl({ tab: t, page: "1" })}
className={`px-4 py-2 text-sm font-medium ${tab === t ? "border-b-2 border-primary text-primary" : "text-muted-foreground hover:text-foreground"}`}
>
{t === "overview"
? "By Provider"
: t === "by-user"
? "By User"
: "Raw Logs"}
</button>
))}
</div>
{tab === "overview" && dashboard && (
<div
role="tabpanel"
id="tabpanel-overview"
aria-labelledby="tab-overview"
>
<ProviderTable
data={dashboard.by_provider}
rateOverrides={rateOverrides}
onRateOverride={handleRateOverride}
/>
</div>
)}
{tab === "by-user" && dashboard && (
<div
role="tabpanel"
id="tabpanel-by-user"
aria-labelledby="tab-by-user"
>
<UserTable data={dashboard.by_user} />
</div>
)}
{tab === "logs" && (
<div role="tabpanel" id="tabpanel-logs" aria-labelledby="tab-logs">
<LogsTable
logs={logs}
pagination={pagination}
onPageChange={(p) => updateUrl({ page: p.toString() })}
/>
</div>
)}
</>
)}
</div>
);
}
export { PlatformCostContent };

@@ -1,131 +0,0 @@
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
import {
defaultRateFor,
estimateCostForRow,
formatMicrodollars,
rateKey,
rateUnitLabel,
trackingValue,
} from "../helpers";
import { TrackingBadge } from "./TrackingBadge";
interface Props {
data: ProviderCostSummary[];
rateOverrides: Record<string, number>;
onRateOverride: (key: string, val: number | null) => void;
}
function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
return (
<div className="overflow-x-auto">
<table className="w-full text-left text-sm">
<thead className="border-b text-xs uppercase text-muted-foreground">
<tr>
<th scope="col" className="px-4 py-3">
Provider
</th>
<th scope="col" className="px-4 py-3">
Type
</th>
<th scope="col" className="px-4 py-3 text-right">
Usage
</th>
<th scope="col" className="px-4 py-3 text-right">
Requests
</th>
<th scope="col" className="px-4 py-3 text-right">
Known Cost
</th>
<th scope="col" className="px-4 py-3 text-right">
Est. Cost
</th>
<th
scope="col"
className="px-4 py-3 text-right"
title="Per-session only"
>
Rate <span className="text-[10px] font-normal">(unsaved)</span>
</th>
</tr>
</thead>
<tbody>
{data.map((row) => {
const est = estimateCostForRow(row, rateOverrides);
const tt = row.tracking_type || "per_run";
            // For cost_usd rows the provider reports USD directly, so the
            // rate input doesn't apply; otherwise show an editable input.
const showRateInput = tt !== "cost_usd";
const key = rateKey(row.provider, tt);
const fallback = defaultRateFor(row.provider, tt);
const currentRate = rateOverrides[key] ?? fallback;
return (
<tr key={key} className="border-b hover:bg-muted">
<td className="px-4 py-3 font-medium">{row.provider}</td>
<td className="px-4 py-3">
<TrackingBadge trackingType={row.tracking_type} />
</td>
<td className="px-4 py-3 text-right">{trackingValue(row)}</td>
<td className="px-4 py-3 text-right">
{row.request_count.toLocaleString()}
</td>
<td className="px-4 py-3 text-right">
{row.total_cost_microdollars > 0
? formatMicrodollars(row.total_cost_microdollars)
: "-"}
</td>
<td className="px-4 py-3 text-right">
{est !== null ? (
formatMicrodollars(est)
) : (
<span className="text-muted-foreground">-</span>
)}
</td>
<td className="px-4 py-2 text-right">
{showRateInput ? (
<div className="flex items-center justify-end gap-1">
<input
type="number"
step="0.0001"
min="0"
aria-label={`Rate for ${row.provider} (${tt})`}
className="w-24 rounded border px-2 py-1 text-right text-xs"
placeholder={fallback !== null ? String(fallback) : "0"}
value={currentRate ?? ""}
onChange={(e) => {
const val = parseFloat(e.target.value);
if (!isNaN(val)) onRateOverride(key, val);
else if (e.target.value === "")
onRateOverride(key, null);
}}
/>
<span
className="text-[10px] text-muted-foreground"
title={rateUnitLabel(tt)}
>
{rateUnitLabel(tt)}
</span>
</div>
) : (
<span className="text-xs text-muted-foreground">auto</span>
)}
</td>
</tr>
);
})}
{data.length === 0 && (
<tr>
<td
colSpan={7}
className="px-4 py-8 text-center text-muted-foreground"
>
No cost data yet
</td>
</tr>
)}
</tbody>
</table>
</div>
);
}
export { ProviderTable };
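The rate-input `onChange` handler above encodes a three-way rule: a parseable number sets an override, an empty field clears it, and any other unparseable text is ignored. A minimal standalone sketch of that rule (the `parseRateInput` name and result shape are hypothetical, not part of the component):

```typescript
// Hypothetical sketch of the rate-input change rule used by ProviderTable:
// numeric text sets an override, "" clears it, anything else is a no-op.
type RateInputAction =
  | { kind: "set"; value: number }
  | { kind: "clear" }
  | { kind: "ignore" };

function parseRateInput(raw: string): RateInputAction {
  const val = parseFloat(raw);
  // parseFloat returns NaN for non-numeric text, including "".
  if (!isNaN(val)) return { kind: "set", value: val };
  if (raw === "") return { kind: "clear" };
  return { kind: "ignore" };
}
```

Note that `parseFloat("0.008x")` still parses the leading number, so partially-typed values behave as "set" rather than "ignore"; that matches the component's behavior, since it only checks `isNaN`.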

@@ -1,19 +0,0 @@
interface Props {
label: string;
value: string;
subtitle?: string;
}
function SummaryCard({ label, value, subtitle }: Props) {
return (
<div className="rounded-lg border p-4">
<div className="text-sm text-muted-foreground">{label}</div>
<div className="text-2xl font-bold">{value}</div>
{subtitle && (
<div className="mt-1 text-xs text-muted-foreground">{subtitle}</div>
)}
</div>
);
}
export { SummaryCard };

@@ -1,25 +0,0 @@
function TrackingBadge({
trackingType,
}: {
trackingType: string | null | undefined;
}) {
const colors: Record<string, string> = {
cost_usd: "bg-green-500/10 text-green-700",
tokens: "bg-blue-500/10 text-blue-700",
characters: "bg-purple-500/10 text-purple-700",
sandbox_seconds: "bg-orange-500/10 text-orange-700",
walltime_seconds: "bg-orange-500/10 text-orange-700",
items: "bg-pink-500/10 text-pink-700",
per_run: "bg-muted text-muted-foreground",
};
const label = trackingType || "per_run";
return (
<span
className={`inline-block rounded px-1.5 py-0.5 text-[10px] font-medium ${colors[label] || colors.per_run}`}
>
{label}
</span>
);
}
export { TrackingBadge };

@@ -1,75 +0,0 @@
import type { PlatformCostDashboard } from "@/app/api/__generated__/models/platformCostDashboard";
import { formatMicrodollars, formatTokens } from "../helpers";
interface Props {
data: PlatformCostDashboard["by_user"];
}
function UserTable({ data }: Props) {
return (
<div className="overflow-x-auto">
<table className="w-full text-left text-sm">
<thead className="border-b text-xs uppercase text-muted-foreground">
<tr>
<th scope="col" className="px-4 py-3">
User
</th>
<th scope="col" className="px-4 py-3 text-right">
Known Cost
</th>
<th scope="col" className="px-4 py-3 text-right">
Requests
</th>
<th scope="col" className="px-4 py-3 text-right">
Input Tokens
</th>
<th scope="col" className="px-4 py-3 text-right">
Output Tokens
</th>
</tr>
</thead>
<tbody>
{data.map((row, idx) => (
<tr
key={row.user_id ?? `unknown-${idx}`}
className="border-b hover:bg-muted"
>
<td className="px-4 py-3">
<div className="font-medium">{row.email || "Unknown"}</div>
<div className="text-xs text-muted-foreground">
{row.user_id}
</div>
</td>
<td className="px-4 py-3 text-right">
{row.total_cost_microdollars > 0
? formatMicrodollars(row.total_cost_microdollars)
: "-"}
</td>
<td className="px-4 py-3 text-right">
{row.request_count.toLocaleString()}
</td>
<td className="px-4 py-3 text-right">
{formatTokens(row.total_input_tokens)}
</td>
<td className="px-4 py-3 text-right">
{formatTokens(row.total_output_tokens)}
</td>
</tr>
))}
{data.length === 0 && (
<tr>
<td
colSpan={5}
className="px-4 py-8 text-center text-muted-foreground"
>
No cost data yet
</td>
</tr>
)}
</tbody>
</table>
</div>
);
}
export { UserTable };

@@ -1,136 +0,0 @@
"use client";
import { useRouter, useSearchParams } from "next/navigation";
import { useState } from "react";
import {
useGetV2GetPlatformCostDashboard,
useGetV2GetPlatformCostLogs,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { okData } from "@/app/api/helpers";
import { estimateCostForRow, toLocalInput, toUtcIso } from "../helpers";
interface InitialSearchParams {
start?: string;
end?: string;
provider?: string;
user_id?: string;
page?: string;
tab?: string;
}
export function usePlatformCostContent(searchParams: InitialSearchParams) {
const router = useRouter();
const urlParams = useSearchParams();
const tab = urlParams.get("tab") || searchParams.tab || "overview";
const page = parseInt(urlParams.get("page") || searchParams.page || "1", 10);
const startDate = urlParams.get("start") || searchParams.start || "";
const endDate = urlParams.get("end") || searchParams.end || "";
const providerFilter =
urlParams.get("provider") || searchParams.provider || "";
const userFilter = urlParams.get("user_id") || searchParams.user_id || "";
const [startInput, setStartInput] = useState(toLocalInput(startDate));
const [endInput, setEndInput] = useState(toLocalInput(endDate));
const [providerInput, setProviderInput] = useState(providerFilter);
const [userInput, setUserInput] = useState(userFilter);
const [rateOverrides, setRateOverrides] = useState<Record<string, number>>(
{},
);
// Pass ISO date strings through `as unknown as Date` so Orval's URL builder
// forwards them as-is. Date.toString() produces a format FastAPI rejects;
// strings pass through .toString() unchanged.
const filterParams = {
start: (startDate || undefined) as unknown as Date | undefined,
end: (endDate || undefined) as unknown as Date | undefined,
provider: providerFilter || undefined,
user_id: userFilter || undefined,
};
const {
data: dashboard,
isLoading: dashLoading,
error: dashError,
} = useGetV2GetPlatformCostDashboard(filterParams, {
query: { select: okData },
});
const {
data: logsResponse,
isLoading: logsLoading,
error: logsError,
} = useGetV2GetPlatformCostLogs(
{ ...filterParams, page, page_size: 50 },
{ query: { select: okData } },
);
const loading = dashLoading || logsLoading;
const error = dashError
? dashError instanceof Error
? dashError.message
: "Failed to load dashboard"
: logsError
? logsError instanceof Error
? logsError.message
: "Failed to load logs"
: null;
function updateUrl(overrides: Record<string, string>) {
const params = new URLSearchParams(urlParams.toString());
for (const [k, v] of Object.entries(overrides)) {
if (v) params.set(k, v);
else params.delete(k);
}
router.push(`/admin/platform-costs?${params.toString()}`);
}
function handleFilter() {
updateUrl({
start: toUtcIso(startInput),
end: toUtcIso(endInput),
provider: providerInput,
user_id: userInput,
page: "1",
});
}
function handleRateOverride(key: string, val: number | null) {
setRateOverrides((prev) => {
if (val === null) {
const { [key]: _, ...rest } = prev;
return rest;
}
return { ...prev, [key]: val };
});
}
const totalEstimatedCost =
dashboard?.by_provider.reduce((sum, row) => {
const est = estimateCostForRow(row, rateOverrides);
return sum + (est ?? 0);
}, 0) ?? 0;
return {
dashboard: dashboard ?? null,
logs: logsResponse?.logs ?? [],
pagination: logsResponse?.pagination ?? null,
loading,
error,
totalEstimatedCost,
tab,
page,
startInput,
setStartInput,
endInput,
setEndInput,
providerInput,
setProviderInput,
userInput,
setUserInput,
rateOverrides,
handleRateOverride,
updateUrl,
handleFilter,
};
}
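The `updateUrl` helper above merges overrides into the current query string: truthy values are set, empty strings delete the key so cleared filters drop out of the URL. A minimal sketch of that merge rule, assuming a hypothetical `mergeParams` helper that returns the resulting query string instead of navigating:

```typescript
// Hypothetical standalone sketch of usePlatformCostContent's updateUrl
// merge rule: non-empty override values are set (replacing any existing
// value in place), and empty strings delete the key entirely.
function mergeParams(
  current: string,
  overrides: Record<string, string>,
): string {
  const params = new URLSearchParams(current);
  for (const [k, v] of Object.entries(overrides)) {
    if (v) params.set(k, v);
    else params.delete(k); // "" clears the filter from the URL
  }
  return params.toString();
}
```

This is why `handleFilter` can pass `page: "1"` unconditionally: applying any filter resets pagination, and passing `""` for an untouched filter removes it rather than serializing an empty parameter.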

@@ -1,204 +0,0 @@
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
const MICRODOLLARS_PER_USD = 1_000_000;
// Per-request cost estimates (USD) for providers billed per API call.
export const DEFAULT_COST_PER_RUN: Record<string, number> = {
google_maps: 0.032, // $0.032/request - Google Maps Places API
ideogram: 0.08, // $0.08/image - Ideogram standard generation
nvidia: 0.0, // Free tier - NVIDIA NIM deepfake detection
screenshotone: 0.01, // ~$0.01/screenshot - ScreenshotOne starter
zerobounce: 0.008, // $0.008/validation - ZeroBounce
mem0: 0.01, // ~$0.01/request - Mem0
openweathermap: 0.0, // Free tier
webshare_proxy: 0.0, // Flat subscription
enrichlayer: 0.1, // ~$0.10/profile lookup
jina: 0.0, // Free tier
};
export const DEFAULT_COST_PER_1K_TOKENS: Record<string, number> = {
openai: 0.005,
anthropic: 0.008,
groq: 0.0003,
ollama: 0.0,
aiml_api: 0.005,
llama_api: 0.003,
v0: 0.005,
};
// Per-character rates (USD / 1K characters) for TTS providers.
export const DEFAULT_COST_PER_1K_CHARS: Record<string, number> = {
unreal_speech: 0.008, // ~$8/1M chars on Starter
elevenlabs: 0.18, // ~$0.18/1K chars on Starter
d_id: 0.04, // ~$0.04/1K chars estimated
};
// Per-item rates (USD / item) for item-count billed APIs.
export const DEFAULT_COST_PER_ITEM: Record<string, number> = {
google_maps: 0.017, // avg of $0.032 nearby + ~$0.015 detail enrich
apollo: 0.02, // ~$0.02/contact on low-volume tiers
smartlead: 0.001, // ~$0.001/lead added
};
// Per-second rates (USD / second) for duration-billed providers.
export const DEFAULT_COST_PER_SECOND: Record<string, number> = {
e2b: 0.000014, // $0.000014/sec (2-core sandbox)
fal: 0.0005, // varies by model, conservative
replicate: 0.001, // varies by hardware
revid: 0.01, // per-second of video
};
export function toDateOrUndefined(val?: string): Date | undefined {
if (!val) return undefined;
const d = new Date(val);
return isNaN(d.getTime()) ? undefined : d;
}
export function formatMicrodollars(microdollars: number) {
return `$${(microdollars / MICRODOLLARS_PER_USD).toFixed(4)}`;
}
export function formatTokens(tokens: number) {
if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`;
if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(1)}K`;
return tokens.toString();
}
export function formatDuration(seconds: number) {
if (seconds >= 3600) return `${(seconds / 3600).toFixed(1)}h`;
if (seconds >= 60) return `${(seconds / 60).toFixed(1)}m`;
return `${seconds.toFixed(1)}s`;
}
// Unit label for each tracking type — what the rate input represents.
export function rateUnitLabel(trackingType: string | null | undefined): string {
switch (trackingType) {
case "tokens":
return "$/1K tokens";
case "characters":
return "$/1K chars";
case "items":
return "$/item";
case "sandbox_seconds":
case "walltime_seconds":
return "$/second";
case "per_run":
return "$/run";
default:
return "";
}
}
// Default rate for a (provider, tracking_type) pair.
export function defaultRateFor(
provider: string,
trackingType: string | null | undefined,
): number | null {
switch (trackingType) {
case "tokens":
return DEFAULT_COST_PER_1K_TOKENS[provider] ?? null;
case "characters":
return DEFAULT_COST_PER_1K_CHARS[provider] ?? null;
case "items":
return DEFAULT_COST_PER_ITEM[provider] ?? null;
case "sandbox_seconds":
case "walltime_seconds":
return DEFAULT_COST_PER_SECOND[provider] ?? null;
case "per_run":
return DEFAULT_COST_PER_RUN[provider] ?? null;
default:
return null;
}
}
// Overrides are keyed on `${provider}:${tracking_type}` since the same
// provider can have multiple rows with different billing models.
export function rateKey(
provider: string,
trackingType: string | null | undefined,
): string {
return `${provider}:${trackingType ?? "per_run"}`;
}
export function estimateCostForRow(
row: ProviderCostSummary,
rateOverrides: Record<string, number>,
) {
const tt = row.tracking_type || "per_run";
// Providers that report USD directly: use known cost.
if (tt === "cost_usd") return row.total_cost_microdollars;
  // Prefer the real USD the provider reported, if any. This only applies
  // on the token path, where OpenRouter piggybacks cost onto the tokens
  // row via x-total-cost.
if (tt === "tokens" && row.total_cost_microdollars > 0) {
return row.total_cost_microdollars;
}
const rate =
rateOverrides[rateKey(row.provider, tt)] ??
defaultRateFor(row.provider, tt);
if (rate === null || rate === undefined) return null;
// Compute the amount for this tracking type, then multiply by rate.
let amount: number;
switch (tt) {
case "tokens":
// Rate is per-1K tokens.
amount = (row.total_input_tokens + row.total_output_tokens) / 1000;
break;
    case "characters":
      // Rate is per-1K chars; total_tracking_amount aggregates char counts.
      amount = (row.total_tracking_amount || 0) / 1000;
break;
case "items":
amount = row.total_tracking_amount || 0;
break;
case "sandbox_seconds":
case "walltime_seconds":
amount = row.total_duration_seconds || 0;
break;
case "per_run":
amount = row.request_count;
break;
default:
return row.total_cost_microdollars > 0
? row.total_cost_microdollars
: null;
}
return Math.round(rate * amount * MICRODOLLARS_PER_USD);
}
export function trackingValue(row: ProviderCostSummary) {
const tt = row.tracking_type || "per_run";
if (tt === "cost_usd") return formatMicrodollars(row.total_cost_microdollars);
if (tt === "tokens") {
const tokens = row.total_input_tokens + row.total_output_tokens;
return `${formatTokens(tokens)} tokens`;
}
if (tt === "sandbox_seconds" || tt === "walltime_seconds")
return formatDuration(row.total_duration_seconds || 0);
if (tt === "characters")
return `${formatTokens(Math.round(row.total_tracking_amount || 0))} chars`;
if (tt === "items")
return `${Math.round(row.total_tracking_amount || 0).toLocaleString()} items`;
return `${row.request_count.toLocaleString()} runs`;
}
// URL holds UTC ISO; datetime-local inputs need local "YYYY-MM-DDTHH:mm".
export function toLocalInput(iso: string) {
if (!iso) return "";
const d = new Date(iso);
if (isNaN(d.getTime())) return "";
const pad = (n: number) => String(n).padStart(2, "0");
return `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())}T${pad(d.getHours())}:${pad(d.getMinutes())}`;
}
// datetime-local emits naive local time; convert to UTC ISO so the
// backend filter window matches what the admin sees in their browser.
export function toUtcIso(local: string) {
if (!local) return "";
const d = new Date(local);
return isNaN(d.getTime()) ? "" : d.toISOString();
}
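The token branch of `estimateCostForRow` above converts a per-1K-token USD rate into integer microdollars. A minimal sketch of just that arithmetic (the `estimateTokenCostMicrodollars` name is hypothetical; the constant and rounding mirror the helper above):

```typescript
// Hedged re-implementation of the tokens branch of estimateCostForRow:
// rate is USD per 1K tokens, and the result is rounded to integer
// microdollars, matching the dashboard's storage unit.
const MICRODOLLARS_PER_USD = 1_000_000;

function estimateTokenCostMicrodollars(
  inputTokens: number,
  outputTokens: number,
  ratePer1kTokens: number,
): number {
  // Sum both directions, scale to thousands, then convert USD to microdollars.
  const thousands = (inputTokens + outputTokens) / 1000;
  return Math.round(ratePer1kTokens * thousands * MICRODOLLARS_PER_USD);
}
```

For example, 1,500 input plus 500 output tokens at $0.005/1K tokens is $0.01, stored as 10,000 microdollars; `formatMicrodollars` would then render it as `$0.0100`.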

@@ -1,50 +0,0 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { Suspense } from "react";
import { PlatformCostContent } from "./components/PlatformCostContent";
type SearchParams = {
start?: string;
end?: string;
provider?: string;
user_id?: string;
page?: string;
tab?: string;
};
function PlatformCostDashboard({
searchParams,
}: {
searchParams: SearchParams;
}) {
return (
<div className="mx-auto p-6">
<div className="flex flex-col gap-4">
<div>
<h1 className="text-3xl font-bold">Platform Costs</h1>
<p className="text-muted-foreground">
Track real API costs incurred by system credentials across providers
</p>
</div>
<Suspense
key={JSON.stringify(searchParams)}
fallback={
<div className="py-10 text-center">Loading cost data...</div>
}
>
<PlatformCostContent searchParams={searchParams} />
</Suspense>
</div>
</div>
);
}
export default async function PlatformCostDashboardPage({
searchParams,
}: {
searchParams: Promise<SearchParams>;
}) {
const withAdminAccess = await withRoleAccess(["admin"]);
const ProtectedDashboard = await withAdminAccess(PlatformCostDashboard);
return <ProtectedDashboard searchParams={await searchParams} />;
}

@@ -7,179 +7,6 @@
"version": "0.1"
},
"paths": {
"/api/admin/platform-costs/dashboard": {
"get": {
"tags": ["v2", "admin", "platform-cost", "admin"],
"summary": "Get Platform Cost Dashboard",
"operationId": "getV2Get platform cost dashboard",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "start",
"in": "query",
"required": false,
"schema": {
"anyOf": [
{ "type": "string", "format": "date-time" },
{ "type": "null" }
],
"title": "Start"
}
},
{
"name": "end",
"in": "query",
"required": false,
"schema": {
"anyOf": [
{ "type": "string", "format": "date-time" },
{ "type": "null" }
],
"title": "End"
}
},
{
"name": "provider",
"in": "query",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Provider"
}
},
{
"name": "user_id",
"in": "query",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "User Id"
}
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/PlatformCostDashboard"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/admin/platform-costs/logs": {
"get": {
"tags": ["v2", "admin", "platform-cost", "admin"],
"summary": "Get Platform Cost Logs",
"operationId": "getV2Get platform cost logs",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "start",
"in": "query",
"required": false,
"schema": {
"anyOf": [
{ "type": "string", "format": "date-time" },
{ "type": "null" }
],
"title": "Start"
}
},
{
"name": "end",
"in": "query",
"required": false,
"schema": {
"anyOf": [
{ "type": "string", "format": "date-time" },
{ "type": "null" }
],
"title": "End"
}
},
{
"name": "provider",
"in": "query",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Provider"
}
},
{
"name": "user_id",
"in": "query",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "User Id"
}
},
{
"name": "page",
"in": "query",
"required": false,
"schema": {
"type": "integer",
"minimum": 1,
"default": 1,
"title": "Page"
}
},
{
"name": "page_size",
"in": "query",
"required": false,
"schema": {
"type": "integer",
"maximum": 200,
"minimum": 1,
"default": 50,
"title": "Page Size"
}
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/PlatformCostLogsResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/analytics/log_raw_analytics": {
"post": {
"tags": ["analytics"],
@@ -8906,61 +8733,6 @@
],
"title": "ContentType"
},
"CostLogRow": {
"properties": {
"id": { "type": "string", "title": "Id" },
"created_at": {
"type": "string",
"format": "date-time",
"title": "Created At"
},
"user_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "User Id"
},
"email": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Email"
},
"graph_exec_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Graph Exec Id"
},
"node_exec_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Node Exec Id"
},
"block_name": { "type": "string", "title": "Block Name" },
"provider": { "type": "string", "title": "Provider" },
"tracking_type": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Tracking Type"
},
"cost_microdollars": {
"anyOf": [{ "type": "integer" }, { "type": "null" }],
"title": "Cost Microdollars"
},
"input_tokens": {
"anyOf": [{ "type": "integer" }, { "type": "null" }],
"title": "Input Tokens"
},
"output_tokens": {
"anyOf": [{ "type": "integer" }, { "type": "null" }],
"title": "Output Tokens"
},
"duration": {
"anyOf": [{ "type": "number" }, { "type": "null" }],
"title": "Duration"
},
"model": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Model"
}
},
"type": "object",
"required": ["id", "created_at", "block_name", "provider"],
"title": "CostLogRow"
},
"CountResponse": {
"properties": {
"all_blocks": { "type": "integer", "title": "All Blocks" },
@@ -11892,48 +11664,6 @@
"title": "PendingHumanReviewModel",
"description": "Response model for pending human review data.\n\nRepresents a human review request that is awaiting user action.\nContains all necessary information for a user to review and approve\nor reject data from a Human-in-the-Loop block execution.\n\nAttributes:\n id: Unique identifier for the review record\n user_id: ID of the user who must perform the review\n node_exec_id: ID of the node execution that created this review\n node_id: ID of the node definition (for grouping reviews from same node)\n graph_exec_id: ID of the graph execution containing the node\n graph_id: ID of the graph template being executed\n graph_version: Version number of the graph template\n payload: The actual data payload awaiting review\n instructions: Instructions or message for the reviewer\n editable: Whether the reviewer can edit the data\n status: Current review status (WAITING, APPROVED, or REJECTED)\n review_message: Optional message from the reviewer\n created_at: Timestamp when review was created\n updated_at: Timestamp when review was last modified\n reviewed_at: Timestamp when review was completed (if applicable)"
},
"PlatformCostDashboard": {
"properties": {
"by_provider": {
"items": { "$ref": "#/components/schemas/ProviderCostSummary" },
"type": "array",
"title": "By Provider"
},
"by_user": {
"items": { "$ref": "#/components/schemas/UserCostSummary" },
"type": "array",
"title": "By User"
},
"total_cost_microdollars": {
"type": "integer",
"title": "Total Cost Microdollars"
},
"total_requests": { "type": "integer", "title": "Total Requests" },
"total_users": { "type": "integer", "title": "Total Users" }
},
"type": "object",
"required": [
"by_provider",
"by_user",
"total_cost_microdollars",
"total_requests",
"total_users"
],
"title": "PlatformCostDashboard"
},
"PlatformCostLogsResponse": {
"properties": {
"logs": {
"items": { "$ref": "#/components/schemas/CostLogRow" },
"type": "array",
"title": "Logs"
},
"pagination": { "$ref": "#/components/schemas/Pagination" }
},
"type": "object",
"required": ["logs", "pagination"],
"title": "PlatformCostLogsResponse"
},
"PostmarkBounceEnum": {
"type": "integer",
"enum": [
@@ -12328,47 +12058,6 @@
"title": "ProviderConstants",
"description": "Model that exposes all provider names as a constant in the OpenAPI schema.\nThis is designed to be converted by Orval into a TypeScript constant."
},
"ProviderCostSummary": {
"properties": {
"provider": { "type": "string", "title": "Provider" },
"tracking_type": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Tracking Type"
},
"total_cost_microdollars": {
"type": "integer",
"title": "Total Cost Microdollars"
},
"total_input_tokens": {
"type": "integer",
"title": "Total Input Tokens"
},
"total_output_tokens": {
"type": "integer",
"title": "Total Output Tokens"
},
"total_duration_seconds": {
"type": "number",
"title": "Total Duration Seconds",
"default": 0.0
},
"total_tracking_amount": {
"type": "number",
"title": "Total Tracking Amount",
"default": 0.0
},
"request_count": { "type": "integer", "title": "Request Count" }
},
"type": "object",
"required": [
"provider",
"total_cost_microdollars",
"total_input_tokens",
"total_output_tokens",
"request_count"
],
"title": "ProviderCostSummary"
},
"ProviderEnumResponse": {
"properties": {
"provider": {
@@ -15249,39 +14938,6 @@
"title": "UsageWindow",
"description": "Usage within a single time window."
},
"UserCostSummary": {
"properties": {
"user_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "User Id"
},
"email": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Email"
},
"total_cost_microdollars": {
"type": "integer",
"title": "Total Cost Microdollars"
},
"total_input_tokens": {
"type": "integer",
"title": "Total Input Tokens"
},
"total_output_tokens": {
"type": "integer",
"title": "Total Output Tokens"
},
"request_count": { "type": "integer", "title": "Request Count" }
},
"type": "object",
"required": [
"total_cost_microdollars",
"total_input_tokens",
"total_output_tokens",
"request_count"
],
"title": "UserCostSummary"
},
"UserHistoryResponse": {
"properties": {
"history": {

(Eight deleted binary image files, 46 KiB to 124 KiB each, not shown.)