Compare commits
90 Commits
dev ... test-scree
| SHA1 |
|---|
| fcbb1613cc |
| 1fa89d9488 |
| 3e016508d4 |
| b80d7abda9 |
| 0e310c788a |
| 91af007c18 |
| e7ca81ed89 |
| 5164fa878f |
| cf605ef5a3 |
| e7bd05c6f1 |
| 22fb3549e3 |
| 1c3fe1444e |
| b89321a688 |
| 630d6d4705 |
| 7c685c6677 |
| bbdf13c7a8 |
| e1ea4cf326 |
| db6b4444e0 |
| 9b1175473b |
| 752a238166 |
| 2a73d1baa9 |
| 254e6057f4 |
| a616e5a060 |
| c9461836c6 |
| 50a8df3d67 |
| 3f7a8dc44d |
| 1c15d6a6cc |
| a31be77408 |
| 1d45f2f18c |
| 27e34e9514 |
| 16d696edcc |
| f87bbd5966 |
| b64d1ed9fa |
| 3895d95826 |
| 181208528f |
| 0365a26c85 |
| fb63ae54f0 |
| 6de79fb73f |
| d57da6c078 |
| 689cd67a13 |
| dca89d1586 |
| 2f63fcd383 |
| f04cd08e40 |
| 44714f1b25 |
| 78b95f8a76 |
| 6f0c1dfa11 |
| 5e595231da |
| 7b36bed8a5 |
| 372900c141 |
| 7afd2b249d |
| 8d22653810 |
| b00e16b438 |
| b5acfb7855 |
| 1ee0bd6619 |
| 4190f75b0b |
| 71315aa982 |
| 960f893295 |
| 759effab60 |
| 45b6ada739 |
| da544d3411 |
| 54e5059d7c |
| 1d7d2f77f3 |
| 567bc73ec4 |
| 61ef54af05 |
| 405403e6b7 |
| ab16e63b0a |
| 45d3193727 |
| 9a08011d7d |
| 6fa66ac7da |
| 4bad08394c |
| 993c43b623 |
| a8a62eeefc |
| 173614bcc5 |
| fbe634fb19 |
| a338c72c42 |
| 7f4398efa3 |
| c2a054c511 |
| 83b00f4789 |
| 95524e94b3 |
| 2c517ff9a1 |
| 7020ae2189 |
| b9336984be |
| 9924dedddc |
| c054799b4f |
| f3b5d584a3 |
| 476d9dcf80 |
| 072b623f8b |
| 26b0c95936 |
| 308357de84 |
| 1a6c50c6cc |
@@ -1,545 +0,0 @@
---
name: orchestrate
description: "Meta-agent supervisor that manages a fleet of Claude Code agents running in tmux windows. Auto-discovers spare worktrees, spawns agents, monitors state, kicks idle agents, approves safe confirmations, and recycles worktrees when done. TRIGGER when user asks to supervise agents, run parallel tasks, manage worktrees, check agent status, or orchestrate parallel work."
user-invocable: true
argument-hint: "any free text — e.g. 'start 3 agents on X Y Z', 'show status', 'add task: implement feature A', 'stop', 'how many are free?'"
metadata:
  author: autogpt-team
  version: "6.0.0"
---

# Orchestrate — Agent Fleet Supervisor

One tmux session, N windows — each window is one agent working in its own worktree. Speak naturally; Claude maps your intent to the right scripts.

## Scripts

```bash
SKILLS_DIR=$(git rev-parse --show-toplevel)/.claude/skills/orchestrate/scripts
STATE_FILE=~/.claude/orchestrator-state.json
```

| Script | Purpose |
|---|---|
| `find-spare.sh [REPO_ROOT]` | List free worktrees — one `PATH BRANCH` per line |
| `spawn-agent.sh SESSION PATH SPARE NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]` | Create window + checkout branch + launch claude + send task. **Stdout: `SESSION:WIN` only** |
| `recycle-agent.sh WINDOW PATH SPARE_BRANCH` | Kill window + restore spare branch |
| `run-loop.sh` | **Mechanical babysitter** — idle restart + dialog approval + recycle on ORCHESTRATOR:DONE + supervisor health check + all-done notification |
| `verify-complete.sh WINDOW` | Verify PR is done: checkpoints ✓ + 0 unresolved threads + CI green + no fresh CHANGES_REQUESTED. Repo auto-derived from state file `.repo` or git remote. |
| `notify.sh MESSAGE` | Send notification via Discord webhook (env `DISCORD_WEBHOOK_URL` or state `.discord_webhook`), macOS notification center, and stdout |
| `capacity.sh [REPO_ROOT]` | Print available + in-use worktrees |
| `status.sh` | Print fleet status + live pane commands |
| `poll-cycle.sh` | One monitoring cycle — classifies panes, tracks checkpoints, returns JSON action array |
| `classify-pane.sh WINDOW` | Classify one pane state |

## Supervision model

```
Orchestrating Claude (this Claude session — IS the supervisor)
  └── Reads pane output, checks CI, intervenes with targeted guidance
run-loop.sh (separate tmux window, every 30s)
  └── Mechanical only: idle restart, dialog approval, recycle on ORCHESTRATOR:DONE
```

**You (the orchestrating Claude)** are the supervisor. After spawning agents, stay in this conversation and actively monitor: poll each agent's pane every 2-3 minutes, check CI, nudge stalled agents, and verify completions. Do not spawn a separate supervisor Claude window — it loses context, is hard to observe, and compounds context compression problems.

**run-loop.sh** is the mechanical layer — zero tokens, handles things that need no judgment: restart crashed agents, press Enter on dialogs, recycle completed worktrees (only after `verify-complete.sh` passes).

## Checkpoint protocol

Agents output checkpoints as they complete each required step:

```
CHECKPOINT:<step-name>
```

Required steps are passed as args to `spawn-agent.sh` (e.g. `pr-address pr-test`). `run-loop.sh` will not recycle a window until all required checkpoints are found in the pane output. If `verify-complete.sh` fails, the agent is re-briefed automatically.
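The checkpoint gate reduces to scanning captured pane output for the required lines. A rough sketch of that idea (the variable contents here are illustrative stand-ins, not what run-loop.sh actually does internally):

```bash
# Sketch: scan captured pane output for the required CHECKPOINT lines.
# PANE_OUTPUT and REQUIRED_STEPS are illustrative stand-ins.
PANE_OUTPUT=$'CHECKPOINT:pr-address\n(other agent output)'
REQUIRED_STEPS="pr-address pr-test"
MISSING=""
for step in $REQUIRED_STEPS; do
  # Each required step must appear as a CHECKPOINT line of its own
  printf '%s\n' "$PANE_OUTPUT" | grep -q "^CHECKPOINT:${step}$" || MISSING="$MISSING $step"
done
# Recycling is allowed only when nothing is missing
echo "missing:${MISSING:- none}"
# → missing: pr-test
```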

## Worktree lifecycle

```text
spare/N branch → spawn-agent.sh (--session-id UUID) → window + feat/branch + claude running
        ↓
CHECKPOINT:<step> (as steps complete)
        ↓
ORCHESTRATOR:DONE
        ↓
verify-complete.sh: checkpoints ✓ + 0 threads + CI green + no fresh CHANGES_REQUESTED
        ↓
state → "done", notify, window KEPT OPEN
        ↓
user/orchestrator explicitly requests recycle
        ↓
recycle-agent.sh → spare/N (free again)
```

**Windows are never auto-killed.** The worktree stays on its branch, the session stays alive. The agent is done working, but the window, git state, and Claude session are all preserved until you choose to recycle.

**To resume a done or crashed session:**

```bash
# Resume by stored session ID (preferred — exact session, full context)
claude --resume SESSION_ID --permission-mode bypassPermissions

# Or resume the most recent session in that worktree directory
cd /path/to/worktree && claude --continue --permission-mode bypassPermissions
```

**To manually recycle when ready:**

```bash
bash ~/.claude/orchestrator/scripts/recycle-agent.sh SESSION:WIN WORKTREE_PATH spare/N
# Then update state:
jq --arg w "SESSION:WIN" '.agents |= map(if .window == $w then .state = "recycled" else . end)' \
  ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```

## State file (`~/.claude/orchestrator-state.json`)

Never committed to git. You maintain this file directly using `jq` + atomic writes (`.tmp` → `mv`).
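The atomic-write rule is tool-agnostic; here is a minimal sketch using `sed` instead of the `jq` edits shown elsewhere (the file path and edit are illustrative):

```bash
# Sketch of the .tmp → mv atomic-write rule (path and edit are illustrative)
STATE_FILE=/tmp/demo-orchestrator-state.json
printf '%s\n' '{"active": true}' > "$STATE_FILE"
TMP="${STATE_FILE}.tmp"
# Write the full new contents to a sibling temp file first...
sed 's/true/false/' "$STATE_FILE" > "$TMP"
# ...then rename: readers see either the old file or the new one, never a partial write
mv "$TMP" "$STATE_FILE"
cat "$STATE_FILE"
# → {"active": false}
```

The rename matters because run-loop.sh and this Claude session both read the file; a direct `>` redirect would let a reader observe a truncated, invalid JSON document mid-write.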

```json
{
  "active": true,
  "tmux_session": "autogpt1",
  "idle_threshold_seconds": 300,
  "loop_window": "autogpt1:5",
  "repo": "Significant-Gravitas/AutoGPT",
  "discord_webhook": "https://discord.com/api/webhooks/...",
  "last_poll_at": 0,
  "agents": [
    {
      "window": "autogpt1:3",
      "worktree": "AutoGPT6",
      "worktree_path": "/path/to/AutoGPT6",
      "spare_branch": "spare/6",
      "branch": "feat/my-feature",
      "objective": "Implement X and open a PR",
      "pr_number": "12345",
      "session_id": "550e8400-e29b-41d4-a716-446655440000",
      "steps": ["pr-address", "pr-test"],
      "checkpoints": ["pr-address"],
      "state": "running",
      "last_output_hash": "",
      "last_seen_at": 0,
      "spawned_at": 0,
      "idle_since": 0,
      "revision_count": 0,
      "last_rebriefed_at": 0
    }
  ]
}
```

Top-level optional fields:

- `repo` — GitHub `owner/repo` for CI/thread checks. Auto-derived from the git remote if omitted.
- `discord_webhook` — Discord webhook URL for completion notifications. The `DISCORD_WEBHOOK_URL` env var is also read.

Per-agent fields:

- `session_id` — UUID passed to `claude --session-id` at spawn; use with `claude --resume UUID` to restore the exact session context after a crash or window close.
- `last_rebriefed_at` — Unix timestamp of the last re-brief; enforces a 5-min cooldown to prevent spam.

Agent states: `running` | `idle` | `stuck` | `waiting_approval` | `complete` | `done` | `escalated`

`done` means verified complete — the window is still open, the session still alive, the worktree still on its task branch. Not recycled yet.

## Serial /pr-test rule

`/pr-test` and `/pr-test --fix` run local Docker + integration tests that use shared ports, a shared database, and shared build caches. **Running two `/pr-test` jobs simultaneously will cause port conflicts and database corruption.**

**Rule: only one `/pr-test` runs at a time. The orchestrator serializes them.**

You (the orchestrating Claude) own the test queue:

1. Agents do `pr-review` and `pr-address` in parallel — that's safe (they only push code and reply to GitHub).
2. When a PR needs local testing, add it to your mental queue — don't give agents a `pr-test` step.
3. Run `/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER --fix` yourself, sequentially.
4. Feed results back to the relevant agent via `tmux send-keys`:
   ```bash
   tmux send-keys -t SESSION:WIN "Local tests for PR #N: <paste failure output or 'all passed'>. Fix any failures and push, then output ORCHESTRATOR:DONE."
   sleep 0.3
   tmux send-keys -t SESSION:WIN Enter
   ```
5. Wait for CI to confirm green before marking the agent done.

If multiple PRs need testing at the same time, pick the one furthest along (fewest pending CI checks) and test it first. Only start the next test after the previous one completes.
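One way to picture the serialization rule is an atomic lock. This is an illustrative sketch only; the orchestrator serializes by judgment and a mental queue, not a lock file:

```bash
# Sketch: "one /pr-test at a time" as an atomic lock directory (path illustrative)
LOCK=/tmp/demo-pr-test.lock
rm -rf "$LOCK"
try_test() {
  # mkdir is atomic: it succeeds for exactly one caller while the lock exists
  if mkdir "$LOCK" 2>/dev/null; then
    echo "testing PR #$1"
  else
    echo "PR #$1 queued - lock held"
  fi
}
try_test 101   # acquires the lock
try_test 102   # lock still held, so this one queues
rmdir "$LOCK"  # release after the test run completes
# → testing PR #101
# → PR #102 queued - lock held
```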

## Session restore (tested and confirmed)

Agent sessions are saved to disk. To restore a closed or crashed session:

```bash
# If session_id is in state (preferred):
NEW_WIN=$(tmux new-window -t SESSION -n WORKTREE_NAME -P -F '#{window_index}')
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --resume SESSION_ID --permission-mode bypassPermissions" Enter

# If no session_id (use --continue for the most recent session in that directory):
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --continue --permission-mode bypassPermissions" Enter
```

`--continue` restores the full conversation history, including all tool calls, file edits, and context. The agent resumes exactly where it left off. After restoring, update the window address in the state file:

```bash
jq --arg old "SESSION:OLD_WIN" --arg new "SESSION:NEW_WIN" \
  '(.agents[] | select(.window == $old)).window = $new' \
  ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```

## Intent → action mapping

Match the user's message to one of these intents:

| The user says something like… | What to do |
|---|---|
| "status", "what's running", "show agents" | Run `status.sh` + `capacity.sh`, show output |
| "how many free", "capacity", "available worktrees" | Run `capacity.sh`, show output |
| "start N agents on X, Y, Z" or "run these tasks: …" | See **Spawning agents** below |
| "add task: …", "add one more agent for …" | See **Adding an agent** below |
| "stop", "shut down", "pause the fleet" | See **When to stop the fleet** below |
| "poll", "check now", "run a cycle" | Run `poll-cycle.sh`, process actions |
| "recycle window X", "free up autogpt3" | Run `recycle-agent.sh` directly |

When the intent is ambiguous, show capacity first and ask what tasks to run.

## Spawning agents

### 1. Resolve tmux session

```bash
tmux list-sessions -F "#{session_name}: #{session_windows} windows" 2>/dev/null
```

Use an existing session. **Never create a tmux session from within Claude** — it becomes a child of Claude's process and dies when the session ends. If no session exists, tell the user to run `tmux new-session -d -s autogpt1` in their terminal first, then re-invoke `/orchestrate`.

### 2. Show available capacity

```bash
bash $SKILLS_DIR/capacity.sh $(git rev-parse --show-toplevel)
```

### 3. Collect tasks from the user

For each task, gather:

- **objective** — what to do (e.g. "implement feature X and open a PR")
- **branch name** — e.g. `feat/my-feature` (derive from the objective if not given)
- **pr_number** — GitHub PR number if working on an existing PR (for verification)
- **steps** — required checkpoint names in order (e.g. `pr-address pr-test`) — derive from the objective

Ask for `idle_threshold_seconds` only if the user mentions it (default: 300).

Never ask the user to specify a worktree — auto-assign from `find-spare.sh`.
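When deriving a branch name from an objective, a simple slug transform is enough. This helper is a hypothetical sketch, not part of the scripts:

```bash
# Hypothetical: derive a feat/ branch slug from an objective when none is given
OBJECTIVE="Implement feature X and open a PR"
BRANCH="feat/$(printf '%s' "$OBJECTIVE" \
  | tr '[:upper:]' '[:lower:]' \
  | tr -cs 'a-z0-9' '-' \
  | cut -c1-40 \
  | sed 's/^-//; s/-$//')"   # trim any leading/trailing dash left by the transform
echo "$BRANCH"
# → feat/implement-feature-x-and-open-a-pr
```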

### 4. Spawn one agent per task

```bash
# Get the ordered list of spare worktrees
SPARE_LIST=$(bash $SKILLS_DIR/find-spare.sh $(git rev-parse --show-toplevel))

# For each task, take the next spare line:
WORKTREE_PATH=$(echo "$SPARE_LINE" | awk '{print $1}')
SPARE_BRANCH=$(echo "$SPARE_LINE" | awk '{print $2}')

# With PR number and required steps:
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE" "$PR_NUMBER" "pr-address" "pr-test")

# Without PR (new work):
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE")
```

If the state file doesn't exist yet, initialize it:

```bash
# Derive repo from the git remote (used by verify-complete.sh + supervisor)
REPO=$(git remote get-url origin 2>/dev/null | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")

jq -n \
  --arg session "$SESSION" \
  --arg repo "$REPO" \
  --argjson threshold 300 \
  '{active:true, tmux_session:$session, idle_threshold_seconds:$threshold,
    repo:$repo, loop_window:null, supervisor_window:null, last_poll_at:0, agents:[]}' \
  > ~/.claude/orchestrator-state.json
```

Optionally add a Discord webhook for completion notifications:

```bash
jq --arg hook "$DISCORD_WEBHOOK_URL" '.discord_webhook = $hook' ~/.claude/orchestrator-state.json \
  > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```

`spawn-agent.sh` writes the initial agent record (window, worktree_path, branch, objective, state, etc.) to the state file automatically — **do not append the record again after calling it.** The record already exists, and `pr_number`/`steps` are patched in by the script itself.

### 5. Start the mechanical babysitter

```bash
LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter

jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
  > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```

### 6. Begin supervising directly in this conversation

You are the supervisor. After spawning, immediately start your first poll loop (see **Supervisor duties** below) and continue every 2-3 minutes. Do NOT spawn a separate supervisor Claude window.

## Adding an agent

Find the next spare worktree, then spawn — same as steps 2–4 above, but for a single task. If no spare worktrees are available, tell the user.

## Supervisor duties (YOUR job, every 2-3 min in this conversation)

You are the supervisor. Run this poll loop directly in your Claude session — not in a separate window.

### Poll loop mechanism

You are reactive — you only act when a tool completes or the user sends a message. To create a self-sustaining poll loop without user involvement:

1. Start each poll with `run_in_background: true` + a sleep before the work:
   ```bash
   sleep 120 && tmux capture-pane -t autogpt1:0 -p -S -200 | tail -40
   # + similar for each active window
   ```
2. When the background job notifies you, read the pane output and take action.
3. Immediately schedule the next background poll — this keeps the loop alive.
4. Stop scheduling when all agents are done/escalated.

**Never tell the user "I'll poll every 2-3 minutes"** — that does nothing without a trigger. Start the background job instead.

### Each poll: what to check

```bash
# 1. Read state
cat ~/.claude/orchestrator-state.json | jq '.agents[] | {window, worktree, branch, state, pr_number, checkpoints}'

# 2. For each running/stuck/idle agent, capture the pane
tmux capture-pane -t SESSION:WIN -p -S -200 | tail -60
```

For each agent, decide:

| What you see | Action |
|---|---|
| Spinner / tools running | Do nothing — the agent is working |
| Idle `❯` prompt, no `ORCHESTRATOR:DONE` | Stalled — send a specific nudge with the objective from state |
| Stuck in an error loop | Send a targeted fix with the exact error + solution |
| Waiting for input / question | Answer and unblock via `tmux send-keys` |
| CI red | `gh pr checks PR_NUMBER --repo REPO` → tell the agent exactly what's failing |
| Context compacted / agent lost | Send recovery: `cat ~/.claude/orchestrator-state.json \| jq '.agents[] \| select(.window=="WIN")'` + `gh pr view PR_NUMBER --json title,body` |
| `ORCHESTRATOR:DONE` in output | Run `verify-complete.sh` — if it fails, re-brief with the specific reason |

### Strict ORCHESTRATOR:DONE gate

`verify-complete.sh` handles the main checks automatically (checkpoints, threads, CI green, spawned_at, and CHANGES_REQUESTED).

**CHANGES_REQUESTED staleness rule**: a `CHANGES_REQUESTED` review only blocks if it was submitted *after* the latest commit. If the latest commit postdates the review, the review is considered stale (feedback already addressed) and does not block. This avoids false negatives when a bot reviewer hasn't re-reviewed after the agent's fixing commits.
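The staleness comparison reduces to two timestamps (the values below are illustrative):

```bash
# Sketch of the staleness rule: a review blocks only if newer than the last commit
REVIEW_SUBMITTED_AT=1700000000   # when the CHANGES_REQUESTED review landed
LAST_COMMIT_AT=1700000500        # the agent pushed after the review
if [ "$REVIEW_SUBMITTED_AT" -gt "$LAST_COMMIT_AT" ]; then
  echo "fresh review - blocks completion"
else
  echo "stale review - ignore"
fi
# → stale review - ignore
```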

```bash
SKILLS_DIR=~/.claude/orchestrator/scripts
bash $SKILLS_DIR/verify-complete.sh SESSION:WIN
```

If it passes → run-loop.sh will recycle the window automatically. No manual action needed.
If it fails → re-brief the agent with the failure reason. Never manually mark state `done` to bypass this.

### Re-brief a stalled agent

```bash
OBJ=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .objective' ~/.claude/orchestrator-state.json)
PR=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .pr_number' ~/.claude/orchestrator-state.json)
tmux send-keys -t SESSION:WIN "You appear stalled. Your objective: $OBJ. Check: gh pr view $PR --json title,body,headRefName to reorient."
sleep 0.3
tmux send-keys -t SESSION:WIN Enter
```

If `image_path` is set on the agent record, include: "Re-read context at IMAGE_PATH with the Read tool."

## Self-recovery protocol (agents)

spawn-agent.sh automatically includes this instruction in every objective:

> If your context compacts and you lose track of what to do, run:
> `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="SESSION:WIN")'`
> and `gh pr view PR_NUMBER --json title,body,headRefName` to reorient.
> Output each completed step as `CHECKPOINT:<step-name>` on its own line.

## Passing images and screenshots to agents

`tmux send-keys` is text-only — you cannot paste a raw image into a pane. To give an agent visual context (screenshots, diagrams, mockups):

1. **Save the image to a temp file** with a stable path:
   ```bash
   # If the user drags in a screenshot or you receive a file path:
   IMAGE_PATH="/tmp/orchestrator-context-$(date +%s).png"
   cp "$USER_PROVIDED_PATH" "$IMAGE_PATH"
   ```
2. **Reference the path in the objective string**:
   ```bash
   OBJECTIVE="Implement the layout shown in /tmp/orchestrator-context-1234567890.png. Read that image first with the Read tool to understand the design."
   ```
3. The agent uses its `Read` tool to view the image at startup — Claude Code agents are multimodal and can read image files directly.

**Rule**: always use `/tmp/orchestrator-context-<timestamp>.png` as the naming convention so the supervisor knows what to look for if it needs to re-brief an agent with the same image.
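The convention is easy to generate and to validate with a glob match; a small sketch:

```bash
# Sketch: generate a context-image path and confirm it fits the naming convention
IMAGE_PATH="/tmp/orchestrator-context-$(date +%s).png"
case "$IMAGE_PATH" in
  /tmp/orchestrator-context-*.png) echo "matches convention" ;;
  *)                               echo "unexpected path" ;;
esac
# → matches convention
```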

---

## Orchestrator final evaluation (YOU decide, not the script)

`verify-complete.sh` is a gate — it blocks premature marking. But it cannot tell you whether the work is actually good. That is YOUR job.

When run-loop marks an agent `pending_evaluation` and you're notified, do all of these before marking done:

### 1. Run /pr-test (required, serialized, use TodoWrite to queue)

`/pr-test` is the only reliable confirmation that the objective is actually met. Run it yourself, not the agent.

**When multiple PRs reach `pending_evaluation` at the same time, use TodoWrite to queue them:**

```
- [ ] /pr-test PR #12636 — fix copilot retry logic
- [ ] /pr-test PR #12699 — builder chat panel
```

Run one at a time. Check off as you go.

```
/pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/PR_NUMBER
```

**/pr-test can be lazy** — if it gives vague output, re-run with full context:

```
/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER
Context: This PR implements <objective from state file>. Key files: <list>.
Please verify: <specific behaviors to check>.
```

Only one `/pr-test` at a time — they share ports and the DB.

### /pr-test result evaluation

**PARTIAL on any headline feature scenario is an immediate blocker.** Do not approve, do not mark done, do not let the agent output `ORCHESTRATOR:DONE`.

| `/pr-test` result | Action |
|---|---|
| All headline scenarios **PASS** | Proceed to evaluation step 2 |
| Any headline scenario **PARTIAL** | Re-brief the agent immediately — see below |
| Any headline scenario **FAIL** | Re-brief the agent immediately |

**What PARTIAL means**: the feature is only partly working. Example: the Apply button never appeared, or the AI returned no action blocks. The agent addressed part of the objective but not all of it.

**When any headline scenario is PARTIAL or FAIL:**

1. Do NOT mark the agent done or accept `ORCHESTRATOR:DONE`
2. Re-brief the agent with the specific scenario that failed and what was missing:
   ```bash
   tmux send-keys -t SESSION:WIN "PARTIAL result on /pr-test — S5 (Apply button) never appeared. The AI must output JSON action blocks for the Apply button to render. Fix this before re-running /pr-test."
   sleep 0.3
   tmux send-keys -t SESSION:WIN Enter
   ```
3. Set state back to `running`:
   ```bash
   jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "running"' \
     ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
   ```
4. Wait for a new `ORCHESTRATOR:DONE`, then re-run `/pr-test` from scratch

**Rule: only ALL-PASS qualifies for approval.** A mix of PASS + PARTIAL is a failure.

> **Why this matters**: PR #12699 was wrongly approved with S5 PARTIAL — the AI never output JSON action blocks, so the Apply button never appeared. The fix was already within the agent's reach but slipped through because PARTIAL was not treated as blocking.

### 2. Do your own evaluation

1. **Read the PR diff and objective** — does the code actually implement what was asked? Is anything obviously missing or half-done?
2. **Read the resolved threads** — were comments addressed with real fixes, or just dismissed/resolved without changes?
3. **Check CI run names** — any suspicious retries that shouldn't have passed?
4. **Check the PR description** — are the title, summary, and test plan complete?

### 3. Decide

- `/pr-test` all scenarios PASS + evaluation looks good → mark `done` in state, tell the user the PR is ready, ask if the window should be closed
- `/pr-test` any scenario PARTIAL or FAIL → re-brief the agent with the specific failing scenario, set state back to `running` (see **/pr-test result evaluation** above)
- Evaluation finds gaps even with all PASS → re-brief the agent with the specific gaps, set state back to `running`

**Never mark done based purely on script output.** You hold the full objective context; the script does not.

```bash
# Mark done after your positive evaluation:
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "done"' \
  ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```

## When to stop the fleet

Stop the fleet (`active = false`) when **all** of the following are true:

| Check | How to verify |
|---|---|
| All agents are `done` or `escalated` | `jq '[.agents[] \| select(.state \| test("running\|stuck\|idle\|waiting_approval"))] \| length' ~/.claude/orchestrator-state.json` == 0 |
| All PRs have 0 unresolved review threads | GraphQL `isResolved` check per PR |
| All PRs have green CI **on a run triggered after the agent's last push** | `gh run list --branch BRANCH --limit 1` timestamp > `spawned_at` in state |
| No fresh CHANGES_REQUESTED (after the latest commit) | `verify-complete.sh` checks this — stale pre-commit reviews are ignored |
| No agents are `escalated` without human review | If any are escalated, surface to the user first |

**Do NOT stop just because agents output `ORCHESTRATOR:DONE`.** That is a signal to verify, not a signal to stop.

**Do stop** if the user explicitly says "stop", "shut down", or "kill everything", even with agents still running.

```bash
# Graceful stop
jq '.active = false' ~/.claude/orchestrator-state.json > /tmp/orch.tmp \
  && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json

LOOP_WINDOW=$(jq -r '.loop_window // ""' ~/.claude/orchestrator-state.json)
[ -n "$LOOP_WINDOW" ] && tmux kill-window -t "$LOOP_WINDOW" 2>/dev/null || true
```

Stopping does **not** recycle running worktrees — agents may still be mid-task. Run `capacity.sh` to see what's still in progress.

## tmux send-keys pattern

**Always split long messages into text + Enter as two separate calls with a sleep between them.** If sent as one call (`"text" Enter`), Enter can fire before the full string is buffered into Claude's input — leaving the message stuck as `[Pasted text +N lines]`, unsent.

```bash
# CORRECT — text, then Enter separately
tmux send-keys -t "$WINDOW" "your long message here"
sleep 0.3
tmux send-keys -t "$WINDOW" Enter

# WRONG — Enter may fire before the text is buffered
tmux send-keys -t "$WINDOW" "your long message here" Enter
```

Short single-character sends (`y`, `Down`, an empty Enter for dialog approval) are safe to combine, since they have no buffering lag.

---

## Protected worktrees

Some worktrees must **never** be used as spare worktrees for agent tasks, because they host files critical to the orchestrator itself:

| Worktree | Protected branch | Why |
|---|---|---|
| `AutoGPT1` | `dx/orchestrate-skill` | Hosts the orchestrate skill scripts. `recycle-agent.sh` would check out `spare/1`, wiping `.claude/skills/` and breaking all subsequent `spawn-agent.sh` calls. |

**Rule**: when selecting spare worktrees via `find-spare.sh`, skip any worktree whose CURRENT branch matches a protected branch. If you accidentally spawn an agent in a protected worktree, do not let `recycle-agent.sh` run on it — manually restore the branch after the agent finishes.
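A sketch of that filter over `find-spare.sh`-style `PATH BRANCH` lines (the paths here are illustrative):

```bash
# Sketch: drop protected branches from find-spare.sh output (paths illustrative)
PROTECTED_BRANCHES="dx/orchestrate-skill"
SPARE_LINES=$'/worktrees/AutoGPT1 dx/orchestrate-skill\n/worktrees/AutoGPT2 spare/2'
printf '%s\n' "$SPARE_LINES" | while read -r path branch; do
  for p in $PROTECTED_BRANCHES; do
    [ "$branch" = "$p" ] && continue 2   # skip protected worktrees entirely
  done
  echo "usable: $path $branch"
done
# → usable: /worktrees/AutoGPT2 spare/2
```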

When `dx/orchestrate-skill` is merged into `dev`, `AutoGPT1` becomes a normal spare again.

---

## Key rules

1. **Scripts do all the heavy lifting** — don't reimplement their logic inline in this file
2. **Never ask the user to pick a worktree** — auto-assign from `find-spare.sh` output
3. **Never restart a running agent** — only restart on `idle` kicks (foreground is a shell)
4. **Auto-dismiss settings dialogs** — if "Enter to confirm" appears, send Down+Enter
5. **Always `--permission-mode bypassPermissions`** on every spawn
6. **Escalate after 3 kicks** — mark `escalated`, surface to the user
7. **Atomic state writes** — always write to `.tmp`, then `mv`
8. **Never approve destructive commands** outside the worktree scope — when in doubt, escalate
9. **Never recycle without verification** — `verify-complete.sh` must pass before recycling
10. **No TASK.md files** — commit risk; use the state file + `gh pr view` for agent context persistence
11. **Re-brief stalled agents** — read the objective from the state file + `gh pr view`, send via tmux
12. **ORCHESTRATOR:DONE is a signal to verify, not to accept** — always run `verify-complete.sh` and check the CI run timestamp before recycling
13. **Protected worktrees** — never use the worktree hosting the skill scripts as a spare
14. **Images via file path** — save screenshots to `/tmp/orchestrator-context-<ts>.png`, pass the path in the objective; agents read it with the `Read` tool
15. **Split send-keys** — always separate text and Enter with `sleep 0.3` between calls for long strings

@@ -1,43 +0,0 @@
#!/usr/bin/env bash
# capacity.sh — show fleet capacity: available spare worktrees + in-use agents
#
# Usage: capacity.sh [REPO_ROOT]
#   REPO_ROOT defaults to the root worktree of the current git repo.
#
# Reads: ~/.claude/orchestrator-state.json (skipped if missing or corrupt)

set -euo pipefail

SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"

echo "=== Available (spare) worktrees ==="
if [ -n "$REPO_ROOT" ]; then
  SPARE=$("$SCRIPTS_DIR/find-spare.sh" "$REPO_ROOT" 2>/dev/null || echo "")
else
  SPARE=$("$SCRIPTS_DIR/find-spare.sh" 2>/dev/null || echo "")
fi

if [ -z "$SPARE" ]; then
  echo "  (none)"
else
  while IFS= read -r line; do
    [ -z "$line" ] && continue
    echo "  ✓ $line"
  done <<< "$SPARE"
fi

echo ""
echo "=== In-use worktrees ==="
if [ -f "$STATE_FILE" ] && jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
  IN_USE=$(jq -r '.agents[] | select(.state != "done") | "  [\(.state)] \(.worktree_path) → \(.branch)"' \
    "$STATE_FILE" 2>/dev/null || echo "")
  if [ -n "$IN_USE" ]; then
    echo "$IN_USE"
  else
    echo "  (none)"
  fi
else
  echo "  (no active state file)"
fi
|
||||
@@ -1,85 +0,0 @@
#!/usr/bin/env bash
# classify-pane.sh — Classify the current state of a tmux pane
#
# Usage: classify-pane.sh <tmux-target>
#   tmux-target: e.g. "work:0", "work:1.0"
#
# Output (stdout): JSON object:
#   { "state": "running|idle|waiting_approval|complete", "reason": "...", "pane_cmd": "..." }
#
# Exit codes: 0=ok, 1=error (invalid target or tmux window not found)

set -euo pipefail

TARGET="${1:-}"

if [ -z "$TARGET" ]; then
  echo '{"state":"error","reason":"no target provided","pane_cmd":""}'
  exit 1
fi

# Validate tmux target format: session:window or session:window.pane
if ! [[ "$TARGET" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
  echo '{"state":"error","reason":"invalid tmux target format","pane_cmd":""}'
  exit 1
fi

# Check session exists (use %%:* to extract session name from session:window)
if ! tmux list-windows -t "${TARGET%%:*}" >/dev/null 2>&1; then
  echo '{"state":"error","reason":"tmux target not found","pane_cmd":""}'
  exit 1
fi

# Get the current foreground command in the pane
PANE_CMD=$(tmux display-message -t "$TARGET" -p '#{pane_current_command}' 2>/dev/null || echo "unknown")

# Capture and strip ANSI codes (use perl for cross-platform compatibility — BSD sed lacks \x1b support)
RAW=$(tmux capture-pane -t "$TARGET" -p -S -50 2>/dev/null || echo "")
CLEAN=$(echo "$RAW" | perl -pe 's/\x1b\[[0-9;]*[a-zA-Z]//g; s/\x1b\(B//g; s/\x1b\[\?[0-9]*[hl]//g; s/\r//g' \
  | grep -v '^[[:space:]]*$' || true)

# --- Check: explicit completion marker ---
# Must be on its own line (not buried in the objective text sent at spawn time).
if echo "$CLEAN" | grep -qE "^[[:space:]]*ORCHESTRATOR:DONE[[:space:]]*$"; then
  jq -n --arg cmd "$PANE_CMD" '{"state":"complete","reason":"ORCHESTRATOR:DONE marker found","pane_cmd":$cmd}'
  exit 0
fi

# --- Check: Claude Code approval prompt patterns ---
LAST_40=$(echo "$CLEAN" | tail -40)
APPROVAL_PATTERNS=(
  "Do you want to proceed"
  "Do you want to make this"
  "\\[y/n\\]"
  "\\[Y/n\\]"
  "\\[n/Y\\]"
  "Proceed\\?"
  "Allow this command"
  "Run bash command"
  "Allow bash"
  "Would you like"
  "Press enter to continue"
  "Esc to cancel"
)
for pattern in "${APPROVAL_PATTERNS[@]}"; do
  if echo "$LAST_40" | grep -qiE "$pattern"; then
    jq -n --arg pattern "$pattern" --arg cmd "$PANE_CMD" \
      '{"state":"waiting_approval","reason":"approval pattern: \($pattern)","pane_cmd":$cmd}'
    exit 0
  fi
done

# --- Check: shell prompt (claude has exited) ---
# If the foreground process is a shell (not claude/node), the agent has exited
case "$PANE_CMD" in
  zsh|bash|fish|sh|dash|tcsh|ksh)
    jq -n --arg cmd "$PANE_CMD" \
      '{"state":"idle","reason":"agent exited — shell prompt active","pane_cmd":$cmd}'
    exit 0
    ;;
esac

# Agent is still running (claude/node/python is the foreground process)
jq -n --arg cmd "$PANE_CMD" \
  '{"state":"running","reason":"foreground process: \($cmd)","pane_cmd":$cmd}'
exit 0
@@ -1,24 +0,0 @@
#!/usr/bin/env bash
# find-spare.sh — list worktrees on spare/N branches (free to use)
#
# Usage: find-spare.sh [REPO_ROOT]
#   REPO_ROOT defaults to the root worktree containing the current git repo.
#
# Output (stdout): one line per available worktree: "PATH BRANCH"
#   e.g.: /Users/me/Code/AutoGPT3 spare/3

set -euo pipefail

REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
if [ -z "$REPO_ROOT" ]; then
  echo "Error: not inside a git repo and no REPO_ROOT provided" >&2
  exit 1
fi

git -C "$REPO_ROOT" worktree list --porcelain \
  | awk '
      /^worktree / { path = substr($0, 10) }
      /^branch /   { branch = substr($0, 8); print path " " branch }
    ' \
  | { grep -E " refs/heads/spare/[0-9]+$" || true; } \
  | sed 's|refs/heads/||'
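The porcelain-parsing pipeline can be checked against canned input (the worktree paths here are made up for illustration):

```bash
# Feed the awk/grep/sed pipeline two fake worktree records; only the one on
# a spare/N branch should survive, printed as "PATH BRANCH".
SPARES=$(printf '%s\n' \
  'worktree /Users/me/Code/AutoGPT3' \
  'branch refs/heads/spare/3' \
  'worktree /Users/me/Code/AutoGPT4' \
  'branch refs/heads/feat/x' \
  | awk '
      /^worktree / { path = substr($0, 10) }
      /^branch /   { branch = substr($0, 8); print path " " branch }
    ' \
  | { grep -E " refs/heads/spare/[0-9]+$" || true; } \
  | sed 's|refs/heads/||')
echo "$SPARES"
```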
@@ -1,40 +0,0 @@
#!/usr/bin/env bash
# notify.sh — send a fleet notification message
#
# Delivery order (first available wins):
#   1. Discord webhook — DISCORD_WEBHOOK_URL env var OR state file .discord_webhook
#   2. macOS notification center — osascript (silent fail if unavailable)
#   3. Stdout only
#
# Usage: notify.sh MESSAGE
# Exit: always 0 (notification failure must not abort the caller)

MESSAGE="${1:-}"
[ -z "$MESSAGE" ] && exit 0

STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"

# --- Resolve Discord webhook ---
WEBHOOK="${DISCORD_WEBHOOK_URL:-}"
if [ -z "$WEBHOOK" ] && [ -f "$STATE_FILE" ]; then
  WEBHOOK=$(jq -r '.discord_webhook // ""' "$STATE_FILE" 2>/dev/null || echo "")
fi

# --- Discord delivery ---
if [ -n "$WEBHOOK" ]; then
  PAYLOAD=$(jq -n --arg msg "$MESSAGE" '{"content": $msg}')
  curl -s -X POST "$WEBHOOK" \
    -H "Content-Type: application/json" \
    -d "$PAYLOAD" > /dev/null 2>&1 || true
fi

# --- macOS notification center (silent if not macOS or osascript missing) ---
if command -v osascript >/dev/null 2>&1; then
  # Escape single quotes for AppleScript
  SAFE_MSG=$(echo "$MESSAGE" | sed "s/'/\\\\'/g")
  osascript -e "display notification \"${SAFE_MSG}\" with title \"Orchestrator\"" 2>/dev/null || true
fi

# Always print to stdout so run-loop.sh logs it
echo "$MESSAGE"
exit 0
@@ -1,257 +0,0 @@
#!/usr/bin/env bash
# poll-cycle.sh — Single orchestrator poll cycle
#
# Reads ~/.claude/orchestrator-state.json, classifies each agent, updates state,
# and outputs a JSON array of actions for Claude to take.
#
# Usage: poll-cycle.sh
# Output (stdout): JSON array of action objects
#   [{ "window": "work:0", "action": "kick|approve|none", "state": "...",
#      "worktree": "...", "objective": "...", "reason": "..." }]
#
# The state file is updated in-place (atomic write via .tmp).

set -euo pipefail

STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
CLASSIFY="$SCRIPTS_DIR/classify-pane.sh"

# Cross-platform md5: always outputs just the hex digest
md5_hash() {
  if command -v md5sum >/dev/null 2>&1; then
    md5sum | awk '{print $1}'
  else
    md5 | awk '{print $NF}'
  fi
}

# Clean up temp file on any exit (avoids stale .tmp if jq write fails)
trap 'rm -f "${STATE_FILE}.tmp"' EXIT

# Ensure state file exists
if [ ! -f "$STATE_FILE" ]; then
  echo '{"active":false,"agents":[]}' > "$STATE_FILE"
fi

# Validate JSON upfront before any jq reads that run under set -e.
# A truncated/corrupt file (e.g. from a SIGKILL mid-write) would otherwise
# abort the script at the ACTIVE read below without emitting any JSON output.
if ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
  echo "State file parse error — check $STATE_FILE" >&2
  echo "[]"
  exit 0
fi

ACTIVE=$(jq -r '.active // false' "$STATE_FILE")
if [ "$ACTIVE" != "true" ]; then
  echo "[]"
  exit 0
fi

NOW=$(date +%s)
IDLE_THRESHOLD=$(jq -r '.idle_threshold_seconds // 300' "$STATE_FILE")

ACTIONS="[]"
UPDATED_AGENTS="[]"

# Read agents as newline-delimited JSON objects.
# jq exits non-zero when .agents[] has no matches on an empty array, which is valid —
# so we suppress that exit code and separately validate the file is well-formed JSON.
if ! AGENTS_JSON=$(jq -e -c '.agents // empty | .[]' "$STATE_FILE" 2>/dev/null); then
  if ! jq -e '.' "$STATE_FILE" > /dev/null 2>&1; then
    echo "State file parse error — check $STATE_FILE" >&2
  fi
  echo "[]"
  exit 0
fi

if [ -z "$AGENTS_JSON" ]; then
  echo "[]"
  exit 0
fi

while IFS= read -r agent; do
  [ -z "$agent" ] && continue

  # Use // "" defaults so a single malformed field doesn't abort the whole cycle
  WINDOW=$(echo "$agent" | jq -r '.window // ""')
  WORKTREE=$(echo "$agent" | jq -r '.worktree // ""')
  OBJECTIVE=$(echo "$agent" | jq -r '.objective // ""')
  STATE=$(echo "$agent" | jq -r '.state // "running"')
  LAST_HASH=$(echo "$agent" | jq -r '.last_output_hash // ""')
  IDLE_SINCE=$(echo "$agent" | jq -r '.idle_since // 0')
  REVISION_COUNT=$(echo "$agent" | jq -r '.revision_count // 0')

  # Validate window format to prevent tmux target injection.
  # Allow session:window (numeric or named) and session:window.pane
  if ! [[ "$WINDOW" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
    echo "Skipping agent with invalid window value: $WINDOW" >&2
    UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
    continue
  fi

  # Pass-through terminal-state agents
  if [[ "$STATE" == "done" || "$STATE" == "escalated" || "$STATE" == "complete" || "$STATE" == "pending_evaluation" ]]; then
    UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
    continue
  fi

  # Classify pane.
  # classify-pane.sh always emits JSON before exit (even on error), so using
  # "|| echo '...'" would concatenate two JSON objects when it exits non-zero.
  # Use "|| true" inside the substitution so set -euo pipefail does not abort
  # the poll cycle when classify exits with a non-zero status code.
  CLASSIFICATION=$("$CLASSIFY" "$WINDOW" 2>/dev/null || true)
  [ -z "$CLASSIFICATION" ] && CLASSIFICATION='{"state":"error","reason":"classify failed","pane_cmd":"unknown"}'

  PANE_STATE=$(echo "$CLASSIFICATION" | jq -r '.state')
  PANE_REASON=$(echo "$CLASSIFICATION" | jq -r '.reason')

  # Capture full pane output once — used for hash (stuck detection) and checkpoint parsing.
  # Use -S -500 to get the last ~500 lines of scrollback so checkpoints aren't missed.
  RAW=$(tmux capture-pane -t "$WINDOW" -p -S -500 2>/dev/null || echo "")

  # --- Checkpoint tracking ---
  # Parse any "CHECKPOINT:<step>" lines the agent has output and merge into state file.
  # The agent writes these as it completes each required step so verify-complete.sh can gate recycling.
  EXISTING_CPS=$(echo "$agent" | jq -c '.checkpoints // []')
  NEW_CHECKPOINTS_JSON="$EXISTING_CPS"
  if [ -n "$RAW" ]; then
    FOUND_CPS=$(echo "$RAW" \
      | grep -oE "CHECKPOINT:[a-zA-Z0-9_-]+" \
      | sed 's/CHECKPOINT://' \
      | sort -u \
      | jq -R . | jq -s . 2>/dev/null || echo "[]")
    NEW_CHECKPOINTS_JSON=$(jq -n \
      --argjson existing "$EXISTING_CPS" \
      --argjson found "$FOUND_CPS" \
      '($existing + $found) | unique' 2>/dev/null || echo "$EXISTING_CPS")
  fi

  # Compute content hash for stuck-detection (only for running agents)
  CURRENT_HASH=""
  if [[ "$PANE_STATE" == "running" ]] && [ -n "$RAW" ]; then
    CURRENT_HASH=$(echo "$RAW" | tail -20 | md5_hash)
  fi

  NEW_STATE="$STATE"
  NEW_IDLE_SINCE="$IDLE_SINCE"
  NEW_REVISION_COUNT="$REVISION_COUNT"
  ACTION="none"
  REASON="$PANE_REASON"

  case "$PANE_STATE" in
    complete)
      # Agent output ORCHESTRATOR:DONE — mark pending_evaluation so orchestrator handles it.
      # run-loop does NOT verify or notify; orchestrator's background poll picks this up.
      NEW_STATE="pending_evaluation"
      ACTION="complete"  # run-loop logs it but takes no action
      ;;
    waiting_approval)
      NEW_STATE="waiting_approval"
      ACTION="approve"
      ;;
    idle)
      # Agent process has exited — needs restart
      NEW_STATE="idle"
      ACTION="kick"
      REASON="agent exited (shell is foreground)"
      NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
      NEW_IDLE_SINCE=$NOW
      if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
        NEW_STATE="escalated"
        ACTION="none"
        REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
      fi
      ;;
    running)
      # Clear idle_since only when transitioning from idle (agent was kicked and
      # restarted). Do NOT reset for stuck — idle_since must persist across polls
      # so STUCK_DURATION can accumulate and trigger escalation.
      # Also update the local IDLE_SINCE so the hash-stability check below uses
      # the reset value on this same poll, not the stale kick timestamp.
      if [[ "$STATE" == "idle" ]]; then
        NEW_IDLE_SINCE=0
        IDLE_SINCE=0
      fi
      # Check if hash has been stable (agent may be stuck mid-task)
      if [ -n "$CURRENT_HASH" ] && [ "$CURRENT_HASH" = "$LAST_HASH" ] && [ "$LAST_HASH" != "" ]; then
        if [ "$IDLE_SINCE" = "0" ] || [ "$IDLE_SINCE" = "null" ]; then
          NEW_IDLE_SINCE=$NOW
        else
          STUCK_DURATION=$(( NOW - IDLE_SINCE ))
          if [ "$STUCK_DURATION" -gt "$IDLE_THRESHOLD" ]; then
            NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
            NEW_IDLE_SINCE=$NOW
            if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
              NEW_STATE="escalated"
              ACTION="none"
              REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
            else
              NEW_STATE="stuck"
              ACTION="kick"
              REASON="output unchanged for ${STUCK_DURATION}s (threshold: ${IDLE_THRESHOLD}s)"
            fi
          fi
        fi
      else
        # Only reset the idle timer when we have a valid hash comparison (pane
        # capture succeeded). If CURRENT_HASH is empty (tmux capture-pane failed),
        # preserve existing timers so stuck detection is not inadvertently reset.
        if [ -n "$CURRENT_HASH" ]; then
          NEW_STATE="running"
          NEW_IDLE_SINCE=0
        fi
      fi
      ;;
    error)
      REASON="classify error: $PANE_REASON"
      ;;
  esac

  # Build updated agent record (ensure idle_since and revision_count are numeric).
  # The explicit || fallback below means a malformed field skips this agent
  # rather than aborting the entire poll cycle under set -e.
  UPDATED_AGENT=$(echo "$agent" | jq \
    --arg state "$NEW_STATE" \
    --arg hash "$CURRENT_HASH" \
    --argjson now "$NOW" \
    --arg idle_since "$NEW_IDLE_SINCE" \
    --arg revision_count "$NEW_REVISION_COUNT" \
    --argjson checkpoints "$NEW_CHECKPOINTS_JSON" \
    '.state = $state
     | .last_output_hash = (if $hash == "" then .last_output_hash else $hash end)
     | .last_seen_at = $now
     | .idle_since = ($idle_since | tonumber)
     | .revision_count = ($revision_count | tonumber)
     | .checkpoints = $checkpoints' 2>/dev/null) || {
    echo "Warning: failed to build updated agent for window $WINDOW — keeping original" >&2
    UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
    continue
  }

  UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$UPDATED_AGENT" '. + [$a]')

  # Add action if needed
  if [ "$ACTION" != "none" ]; then
    ACTION_OBJ=$(jq -n \
      --arg window "$WINDOW" \
      --arg action "$ACTION" \
      --arg state "$NEW_STATE" \
      --arg worktree "$WORKTREE" \
      --arg objective "$OBJECTIVE" \
      --arg reason "$REASON" \
      '{window:$window, action:$action, state:$state, worktree:$worktree, objective:$objective, reason:$reason}')
    ACTIONS=$(echo "$ACTIONS" | jq --argjson a "$ACTION_OBJ" '. + [$a]')
  fi

done <<< "$AGENTS_JSON"

# Atomic state file update
jq --argjson agents "$UPDATED_AGENTS" \
  --argjson now "$NOW" \
  '.agents = $agents | .last_poll_at = $now' \
  "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"

echo "$ACTIONS"
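The checkpoint-extraction step in the loop above can be run standalone (the pane text here is invented for illustration):

```bash
# Standalone check of checkpoint extraction: pull the unique
# CHECKPOINT:<step> names out of captured pane text, ignoring surrounding
# chatter and collapsing duplicates, exactly as poll-cycle.sh does before
# merging them into the agent's state record.
RAW='CHECKPOINT:pr-address
some agent chatter
CHECKPOINT:pr-test
CHECKPOINT:pr-address again'
FOUND=$(echo "$RAW" | grep -oE "CHECKPOINT:[a-zA-Z0-9_-]+" | sed 's/CHECKPOINT://' | sort -u)
echo "$FOUND"
```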
@@ -1,32 +0,0 @@
#!/usr/bin/env bash
# recycle-agent.sh — kill a tmux window and restore the worktree to its spare branch
#
# Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH
#   WINDOW        — tmux target, e.g. autogpt1:3
#   WORKTREE_PATH — absolute path to the git worktree
#   SPARE_BRANCH  — branch to restore, e.g. spare/6
#
# Stdout: one status line

set -euo pipefail

if [ $# -lt 3 ]; then
  echo "Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH" >&2
  exit 1
fi

WINDOW="$1"
WORKTREE_PATH="$2"
SPARE_BRANCH="$3"

# Kill the tmux window (ignore error — may already be gone)
tmux kill-window -t "$WINDOW" 2>/dev/null || true

# Restore to spare branch: abort any in-progress operation, then clean
git -C "$WORKTREE_PATH" rebase --abort 2>/dev/null || true
git -C "$WORKTREE_PATH" merge --abort 2>/dev/null || true
git -C "$WORKTREE_PATH" reset --hard HEAD 2>/dev/null
git -C "$WORKTREE_PATH" clean -fd 2>/dev/null
git -C "$WORKTREE_PATH" checkout "$SPARE_BRANCH"

echo "Recycled: $(basename "$WORKTREE_PATH") → $SPARE_BRANCH (window $WINDOW closed)"
@@ -1,164 +0,0 @@
#!/usr/bin/env bash
# run-loop.sh — Mechanical babysitter for the agent fleet (runs in its own tmux window)
#
# Handles ONLY two things that need no intelligence:
#   idle    → restart claude using --resume SESSION_ID (or --continue) to restore context
#   approve → auto-approve safe dialogs, press Enter on numbered-option dialogs
#
# Everything else — ORCHESTRATOR:DONE, verification, /pr-test, final evaluation,
# marking done, deciding to close windows — is the orchestrating Claude's job.
# poll-cycle.sh sets state to pending_evaluation when ORCHESTRATOR:DONE is detected;
# the orchestrator's background poll loop handles it from there.
#
# Usage: run-loop.sh
# Env: POLL_INTERVAL (default: 30), ORCHESTRATOR_STATE_FILE

set -euo pipefail

# Copy scripts to a stable location outside the repo so they survive branch
# checkouts (e.g. recycle-agent.sh switching spare/N back into this worktree
# would wipe .claude/skills/orchestrate/scripts if the skill only exists on the
# current branch).
_ORIGIN_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
STABLE_SCRIPTS_DIR="$HOME/.claude/orchestrator/scripts"
mkdir -p "$STABLE_SCRIPTS_DIR"
cp "$_ORIGIN_DIR"/*.sh "$STABLE_SCRIPTS_DIR/"
chmod +x "$STABLE_SCRIPTS_DIR"/*.sh
SCRIPTS_DIR="$STABLE_SCRIPTS_DIR"

STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
POLL_INTERVAL="${POLL_INTERVAL:-30}"

# ---------------------------------------------------------------------------
# update_state WINDOW FIELD VALUE
# ---------------------------------------------------------------------------
update_state() {
  local window="$1" field="$2" value="$3"
  jq --arg w "$window" --arg f "$field" --arg v "$value" \
    '.agents |= map(if .window == $w then .[$f] = $v else . end)' \
    "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
}

update_state_int() {
  local window="$1" field="$2" value="$3"
  jq --arg w "$window" --arg f "$field" --argjson v "$value" \
    '.agents |= map(if .window == $w then .[$f] = $v else . end)' \
    "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
}

agent_field() {
  jq -r --arg w "$1" --arg f "$2" \
    '.agents[] | select(.window == $w) | .[$f] // ""' \
    "$STATE_FILE" 2>/dev/null
}

# ---------------------------------------------------------------------------
# wait_for_prompt WINDOW — wait up to 60s for Claude's ❯ prompt
# ---------------------------------------------------------------------------
wait_for_prompt() {
  local window="$1"
  for i in $(seq 1 60); do
    local cmd pane
    cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
    pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
    if echo "$pane" | grep -q "Enter to confirm"; then
      tmux send-keys -t "$window" Down Enter; sleep 2; continue
    fi
    [[ "$cmd" == "node" ]] && echo "$pane" | grep -q "❯" && return 0
    sleep 1
  done
  return 1  # timed out
}

# ---------------------------------------------------------------------------
# handle_kick WINDOW STATE — only for idle (crashed) agents, not stuck
# ---------------------------------------------------------------------------
handle_kick() {
  local window="$1" state="$2"
  [[ "$state" != "idle" ]] && return  # stuck agents handled by supervisor

  local worktree_path session_id
  worktree_path=$(agent_field "$window" "worktree_path")
  session_id=$(agent_field "$window" "session_id")

  echo "[$(date +%H:%M:%S)] KICK restart $window — agent exited, resuming session"

  # Resume the exact session so the agent retains full context — no need to re-send objective
  if [ -n "$session_id" ]; then
    tmux send-keys -t "$window" "cd '${worktree_path}' && claude --resume '${session_id}' --permission-mode bypassPermissions" Enter
  else
    tmux send-keys -t "$window" "cd '${worktree_path}' && claude --continue --permission-mode bypassPermissions" Enter
  fi

  wait_for_prompt "$window" || echo "[$(date +%H:%M:%S)] KICK WARNING $window — timed out waiting for ❯"
}

# ---------------------------------------------------------------------------
# handle_approve WINDOW — auto-approve dialogs that need no judgment
# ---------------------------------------------------------------------------
handle_approve() {
  local window="$1"
  local pane_tail
  pane_tail=$(tmux capture-pane -t "$window" -p 2>/dev/null | tail -3 || echo "")

  # Settings error dialog at startup
  if echo "$pane_tail" | grep -q "Enter to confirm"; then
    echo "[$(date +%H:%M:%S)] APPROVE dialog $window — settings error"
    tmux send-keys -t "$window" Down Enter
    return
  fi

  # Numbered-option dialog (e.g. "Do you want to make this edit?")
  # ❯ is already on option 1 (Yes) — Enter confirms it
  if echo "$pane_tail" | grep -qE "❯\s*1\." || echo "$pane_tail" | grep -q "Esc to cancel"; then
    echo "[$(date +%H:%M:%S)] APPROVE edit $window"
    tmux send-keys -t "$window" "" Enter
    return
  fi

  # y/n prompt for safe operations
  if echo "$pane_tail" | grep -qiE "(^git |^npm |^pnpm |^poetry |^pytest|^docker |^make |^cargo |^pip |^yarn |curl .*(localhost|127\.0\.0\.1))"; then
    echo "[$(date +%H:%M:%S)] APPROVE safe $window"
    tmux send-keys -t "$window" "y" Enter
    return
  fi

  # Anything else — supervisor handles it, just log
  echo "[$(date +%H:%M:%S)] APPROVE skip $window — unknown dialog, supervisor will handle"
}

# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------
echo "[$(date +%H:%M:%S)] run-loop started (mechanical only, poll every ${POLL_INTERVAL}s)"
echo "[$(date +%H:%M:%S)] Supervisor: orchestrating Claude session (not a separate window)"
echo "---"

while true; do
  if ! jq -e '.active == true' "$STATE_FILE" >/dev/null 2>&1; then
    echo "[$(date +%H:%M:%S)] active=false — exiting."
    exit 0
  fi

  ACTIONS=$("$SCRIPTS_DIR/poll-cycle.sh" 2>/dev/null || echo "[]")
  KICKED=0; DONE=0

  while IFS= read -r action; do
    [ -z "$action" ] && continue
    WINDOW=$(echo "$action" | jq -r '.window // ""')
    ACTION=$(echo "$action" | jq -r '.action // ""')
    STATE=$(echo "$action" | jq -r '.state // ""')

    case "$ACTION" in
      kick)     handle_kick "$WINDOW" "$STATE" || true; KICKED=$(( KICKED + 1 )) ;;
      approve)  handle_approve "$WINDOW" || true ;;
      complete) DONE=$(( DONE + 1 )) ;;  # poll-cycle already set state=pending_evaluation; orchestrator handles
    esac
  done < <(echo "$ACTIONS" | jq -c '.[]' 2>/dev/null || true)

  RUNNING=$(jq '[.agents[] | select(.state | test("running|stuck|waiting_approval|idle"))] | length' \
    "$STATE_FILE" 2>/dev/null || echo 0)

  echo "[$(date +%H:%M:%S)] Poll — ${RUNNING} running ${KICKED} kicked ${DONE} complete"
  sleep "$POLL_INTERVAL"
done
@@ -1,122 +0,0 @@
#!/usr/bin/env bash
# spawn-agent.sh — create tmux window, checkout branch, launch claude, send task
#
# Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]
#   SESSION       — tmux session name, e.g. autogpt1
#   WORKTREE_PATH — absolute path to the git worktree
#   SPARE_BRANCH  — spare branch being replaced, e.g. spare/6 (saved for recycle)
#   NEW_BRANCH    — task branch to create, e.g. feat/my-feature
#   OBJECTIVE     — task description sent to the agent
#   PR_NUMBER     — (optional) GitHub PR number for completion verification
#   STEPS...      — (optional) required checkpoint names, e.g. pr-address pr-test
#
# Stdout: SESSION:WINDOW_INDEX (nothing else — callers rely on this)
# Exit non-zero on failure.

set -euo pipefail

if [ $# -lt 5 ]; then
  echo "Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]" >&2
  exit 1
fi

SESSION="$1"
WORKTREE_PATH="$2"
SPARE_BRANCH="$3"
NEW_BRANCH="$4"
OBJECTIVE="$5"
PR_NUMBER="${6:-}"
STEPS=("${@:7}")
WORKTREE_NAME=$(basename "$WORKTREE_PATH")
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"

# Generate a stable session ID so this agent's Claude session can always be resumed:
#   claude --resume $SESSION_ID --permission-mode bypassPermissions
SESSION_ID=$(uuidgen 2>/dev/null || python3 -c "import uuid; print(uuid.uuid4())")

# Create (or switch to) the task branch
git -C "$WORKTREE_PATH" checkout -b "$NEW_BRANCH" 2>/dev/null \
  || git -C "$WORKTREE_PATH" checkout "$NEW_BRANCH"

# Open a new named tmux window; capture its numeric index
WIN_IDX=$(tmux new-window -t "$SESSION" -n "$WORKTREE_NAME" -P -F '#{window_index}')
WINDOW="${SESSION}:${WIN_IDX}"

# Append the initial agent record to the state file so subsequent jq updates find it.
# This must happen before the pr_number/steps update below.
if [ -f "$STATE_FILE" ]; then
  NOW=$(date +%s)
  jq --arg window "$WINDOW" \
    --arg worktree "$WORKTREE_NAME" \
    --arg worktree_path "$WORKTREE_PATH" \
    --arg spare_branch "$SPARE_BRANCH" \
    --arg branch "$NEW_BRANCH" \
    --arg objective "$OBJECTIVE" \
    --arg session_id "$SESSION_ID" \
    --argjson now "$NOW" \
    '.agents += [{
      "window": $window,
      "worktree": $worktree,
      "worktree_path": $worktree_path,
      "spare_branch": $spare_branch,
      "branch": $branch,
      "objective": $objective,
      "session_id": $session_id,
      "state": "running",
      "checkpoints": [],
      "last_output_hash": "",
      "last_seen_at": $now,
      "spawned_at": $now,
      "idle_since": 0,
      "revision_count": 0,
      "last_rebriefed_at": 0
    }]' "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
fi

# Store pr_number + steps in state file if provided (enables verify-complete.sh).
# The agent record was appended above so the jq select now finds it.
if [ -n "$PR_NUMBER" ] && [ -f "$STATE_FILE" ]; then
  if [ "${#STEPS[@]}" -gt 0 ]; then
    STEPS_JSON=$(printf '%s\n' "${STEPS[@]}" | jq -R . | jq -s .)
  else
    STEPS_JSON='[]'
  fi
  jq --arg w "$WINDOW" --arg pr "$PR_NUMBER" --argjson steps "$STEPS_JSON" \
    '.agents |= map(if .window == $w then . + {pr_number: $pr, steps: $steps, checkpoints: []} else . end)' \
    "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
fi

# Launch claude with a stable session ID so it can always be resumed after a crash:
#   claude --resume SESSION_ID --permission-mode bypassPermissions
tmux send-keys -t "$WINDOW" "cd '${WORKTREE_PATH}' && claude --permission-mode bypassPermissions --session-id '${SESSION_ID}'" Enter

# Wait up to 60s for claude to be fully interactive:
# both pane_current_command == 'node' AND the '❯' prompt is visible.
PROMPT_FOUND=false
for i in $(seq 1 60); do
  CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "")
  PANE=$(tmux capture-pane -t "$WINDOW" -p 2>/dev/null || echo "")
  if echo "$PANE" | grep -q "Enter to confirm"; then
    tmux send-keys -t "$WINDOW" Down Enter
    sleep 2
    continue
  fi
  if [[ "$CMD" == "node" ]] && echo "$PANE" | grep -q "❯"; then
    PROMPT_FOUND=true
    break
  fi
  sleep 1
done

if ! $PROMPT_FOUND; then
  echo "[spawn-agent] WARNING: timed out waiting for ❯ prompt on $WINDOW — sending objective anyway" >&2
fi

# Send the task. Split text and Enter — if combined, Enter can fire before the string
# is fully buffered, leaving the message stuck as "[Pasted text +N lines]" unsent.
tmux send-keys -t "$WINDOW" "${OBJECTIVE} Output each completed step as CHECKPOINT:<step-name>. When ALL steps are done, output ORCHESTRATOR:DONE on its own line."
sleep 0.3
tmux send-keys -t "$WINDOW" Enter

# Only output the window address — nothing else (callers parse this)
echo "$WINDOW"
@@ -1,43 +0,0 @@
#!/usr/bin/env bash
# status.sh — print orchestrator status: state file summary + live tmux pane commands
#
# Usage: status.sh
# Reads: ~/.claude/orchestrator-state.json

set -euo pipefail

STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"

if [ ! -f "$STATE_FILE" ] || ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
  echo "No orchestrator state found at $STATE_FILE"
  exit 0
fi

# Header: active status, session, thresholds, last poll
jq -r '
  "=== Orchestrator [\(if .active then "RUNNING" else "STOPPED" end)] ===",
  "Session: \(.tmux_session // "unknown") | Idle threshold: \(.idle_threshold_seconds // 300)s",
  "Last poll: \(if (.last_poll_at // 0) == 0 then "never" else (.last_poll_at | strftime("%H:%M:%S")) end)",
  ""
' "$STATE_FILE"

# Each agent: state, window, worktree/branch, truncated objective
AGENT_COUNT=$(jq '.agents | length' "$STATE_FILE")
if [ "$AGENT_COUNT" -eq 0 ]; then
  echo " (no agents registered)"
else
  jq -r '
    .agents[] |
    " [\(.state | ascii_upcase)] \(.window) \(.worktree)/\(.branch)",
    " \(.objective // "" | .[0:70])"
  ' "$STATE_FILE"
fi

echo ""

# Live pane_current_command for non-done agents
while IFS= read -r WINDOW; do
  [ -z "$WINDOW" ] && continue
  CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "unreachable")
  echo " $WINDOW live: $CMD"
done < <(jq -r '.agents[] | select(.state != "done") | .window' "$STATE_FILE" 2>/dev/null || true)
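The state-file shape status.sh expects is only implied by the jq queries above. A minimal illustrative example (field names inferred from those queries; all values made up):

```shell
# Illustrative state file exercising the header and agent queries above.
STATE_FILE=$(mktemp)
cat > "$STATE_FILE" <<'EOF'
{
  "active": true,
  "tmux_session": "orch",
  "idle_threshold_seconds": 300,
  "last_poll_at": 0,
  "agents": [
    {
      "state": "working",
      "window": "orch:1",
      "worktree": "wt-1",
      "branch": "fix/login-bug",
      "objective": "Fix the login redirect bug and add a regression test"
    }
  ]
}
EOF
# Same header query status.sh runs:
jq -r '"=== Orchestrator [\(if .active then "RUNNING" else "STOPPED" end)] ==="' "$STATE_FILE"
# → === Orchestrator [RUNNING] ===
```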
@@ -1,180 +0,0 @@
#!/usr/bin/env bash
# verify-complete.sh — verify a PR task is truly done before marking the agent done
#
# Check order matters:
# 1. Checkpoints — did the agent do all required steps?
# 2. CI complete — no pending (bots post comments AFTER their check runs, must wait)
# 3. CI passing — no failures (agent must fix before done)
# 4. spawned_at — a new CI run was triggered after agent spawned (proves real work)
# 5. Unresolved threads — checked AFTER CI so bot-posted comments are included
# 6. CHANGES_REQUESTED — checked AFTER CI so bot reviews are included
#
# Usage: verify-complete.sh WINDOW
# Exit 0 = verified complete; exit 1 = not complete (stderr has reason)

set -euo pipefail

WINDOW="$1"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"

PR_NUMBER=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .pr_number // ""' "$STATE_FILE" 2>/dev/null)
STEPS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .steps // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
CHECKPOINTS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .checkpoints // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
WORKTREE_PATH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .worktree_path // ""' "$STATE_FILE" 2>/dev/null)
BRANCH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .branch // ""' "$STATE_FILE" 2>/dev/null)
SPAWNED_AT=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .spawned_at // "0"' "$STATE_FILE" 2>/dev/null || echo "0")

# No PR number = cannot verify
if [ -z "$PR_NUMBER" ]; then
  echo "NOT COMPLETE: no pr_number in state — set pr_number or mark done manually" >&2
  exit 1
fi

# --- Check 1: all required steps are checkpointed ---
MISSING=""
while IFS= read -r step; do
  [ -z "$step" ] && continue
  if ! echo "$CHECKPOINTS" | grep -qFx "$step"; then
    MISSING="$MISSING $step"
  fi
done <<< "$STEPS"

if [ -n "$MISSING" ]; then
  echo "NOT COMPLETE: missing checkpoints:$MISSING on PR #$PR_NUMBER" >&2
  exit 1
fi

# Resolve repo for all GitHub checks below
REPO=$(jq -r '.repo // ""' "$STATE_FILE" 2>/dev/null || echo "")
if [ -z "$REPO" ] && [ -n "$WORKTREE_PATH" ] && [ -d "$WORKTREE_PATH" ]; then
  REPO=$(git -C "$WORKTREE_PATH" remote get-url origin 2>/dev/null \
    | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
fi

if [ -z "$REPO" ]; then
  echo "Warning: cannot resolve repo — skipping CI/thread checks" >&2
  echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓ (CI/thread checks skipped — no repo)"
  exit 0
fi

CI_BUCKETS=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket 2>/dev/null || echo "[]")

# --- Check 2: CI fully complete — no pending checks ---
# Pending checks MUST finish before we check threads/reviews:
# bots (Seer, Check PR Status, etc.) post comments and CHANGES_REQUESTED AFTER their CI check runs.
PENDING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "pending")] | length' 2>/dev/null || echo "0")
if [ "$PENDING" -gt 0 ]; then
  PENDING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
    | jq -r '[.[] | select(.bucket == "pending") | .name] | join(", ")' 2>/dev/null || echo "unknown")
  echo "NOT COMPLETE: $PENDING CI checks still pending on PR #$PR_NUMBER ($PENDING_NAMES)" >&2
  exit 1
fi

# --- Check 3: CI passing — no failures ---
FAILING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "fail")] | length' 2>/dev/null || echo "0")
if [ "$FAILING" -gt 0 ]; then
  FAILING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
    | jq -r '[.[] | select(.bucket == "fail") | .name] | join(", ")' 2>/dev/null || echo "unknown")
  echo "NOT COMPLETE: $FAILING failing CI checks on PR #$PR_NUMBER ($FAILING_NAMES)" >&2
  exit 1
fi

# --- Check 4: a new CI run was triggered AFTER the agent spawned ---
if [ -n "$BRANCH" ] && [ "${SPAWNED_AT:-0}" -gt 0 ]; then
  LATEST_RUN_AT=$(gh run list --repo "$REPO" --branch "$BRANCH" \
    --json createdAt --limit 1 2>/dev/null | jq -r '.[0].createdAt // ""')
  if [ -n "$LATEST_RUN_AT" ]; then
    if date --version >/dev/null 2>&1; then
      LATEST_RUN_EPOCH=$(date -d "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
    else
      LATEST_RUN_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
    fi
    if [ "$LATEST_RUN_EPOCH" -le "$SPAWNED_AT" ]; then
      echo "NOT COMPLETE: latest CI run on $BRANCH predates agent spawn — agent may not have pushed yet" >&2
      exit 1
    fi
  fi
fi

OWNER=$(echo "$REPO" | cut -d/ -f1)
REPONAME=$(echo "$REPO" | cut -d/ -f2)

# --- Check 5: no unresolved review threads (checked AFTER CI — bots post after their check) ---
UNRESOLVED=$(gh api graphql -f query="
  { repository(owner: \"${OWNER}\", name: \"${REPONAME}\") {
      pullRequest(number: ${PR_NUMBER}) {
        reviewThreads(first: 50) { nodes { isResolved } }
      }
    }
  }
" --jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)] | length' 2>/dev/null || echo "0")

if [ "$UNRESOLVED" -gt 0 ]; then
  echo "NOT COMPLETE: $UNRESOLVED unresolved review threads on PR #$PR_NUMBER" >&2
  exit 1
fi

# --- Check 6: no CHANGES_REQUESTED (checked AFTER CI — bots post reviews after their check) ---
# A CHANGES_REQUESTED review is stale if the latest commit was pushed AFTER the review was submitted.
# Stale reviews (pre-dating the fixing commits) should not block verification.
#
# Fetch commits and latestReviews in a single call and fail closed — if gh fails,
# treat that as NOT COMPLETE rather than silently passing.
# Use latestReviews (not reviews) so each reviewer's latest state is used — superseded
# CHANGES_REQUESTED entries are automatically excluded when the reviewer later approved.
# Note: we intentionally use committedDate (not PR updatedAt) because updatedAt changes on any
# PR activity (bot comments, label changes) which would create false negatives.
PR_REVIEW_METADATA=$(gh pr view "$PR_NUMBER" --repo "$REPO" \
  --json commits,latestReviews 2>/dev/null) || {
  echo "NOT COMPLETE: unable to fetch PR review metadata for PR #$PR_NUMBER" >&2
  exit 1
}

LATEST_COMMIT_DATE=$(jq -r '.commits[-1].committedDate // ""' <<< "$PR_REVIEW_METADATA")
CHANGES_REQUESTED_REVIEWS=$(jq '[.latestReviews[]? | select(.state == "CHANGES_REQUESTED")]' <<< "$PR_REVIEW_METADATA")

BLOCKING_CHANGES_REQUESTED=0
BLOCKING_REQUESTERS=""

if [ -n "$LATEST_COMMIT_DATE" ] && [ "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length)" -gt 0 ]; then
  if date --version >/dev/null 2>&1; then
    LATEST_COMMIT_EPOCH=$(date -d "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
  else
    LATEST_COMMIT_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
  fi

  while IFS= read -r review; do
    [ -z "$review" ] && continue
    REVIEW_DATE=$(echo "$review" | jq -r '.submittedAt // ""')
    REVIEWER=$(echo "$review" | jq -r '.author.login // "unknown"')
    if [ -z "$REVIEW_DATE" ]; then
      # No submission date — treat as fresh (conservative: blocks verification)
      BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
      BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
    else
      if date --version >/dev/null 2>&1; then
        REVIEW_EPOCH=$(date -d "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
      else
        REVIEW_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
      fi
      if [ "$REVIEW_EPOCH" -gt "$LATEST_COMMIT_EPOCH" ]; then
        # Review was submitted AFTER latest commit — still fresh, blocks verification
        BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
        BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
      fi
      # Review submitted BEFORE latest commit — stale, skip
    fi
  done <<< "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -c '.[]')"
else
  # No commit date or no changes_requested — check raw count as fallback
  BLOCKING_CHANGES_REQUESTED=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length 2>/dev/null || echo "0")
  BLOCKING_REQUESTERS=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -r '[.[].author.login] | join(", ")' 2>/dev/null || echo "unknown")
fi

if [ "$BLOCKING_CHANGES_REQUESTED" -gt 0 ]; then
  echo "NOT COMPLETE: CHANGES_REQUESTED (after latest commit) from ${BLOCKING_REQUESTERS} on PR #$PR_NUMBER" >&2
  exit 1
fi

echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓, CI complete + green, 0 unresolved threads, no CHANGES_REQUESTED"
exit 0
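The GNU/BSD `date` branching above is repeated three times in the script. As a sketch only (not part of the original script), it could be factored into one helper:

```shell
# Hypothetical helper factoring out the repeated GNU/BSD date branching.
# Converts an ISO-8601 UTC timestamp to seconds since epoch, or "0" on failure.
to_epoch() {
  if date --version >/dev/null 2>&1; then
    # GNU date (Linux)
    date -d "$1" "+%s" 2>/dev/null || echo "0"
  else
    # BSD date (macOS)
    TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$1" "+%s" 2>/dev/null || echo "0"
  fi
}

# Usage: LATEST_RUN_EPOCH=$(to_epoch "$LATEST_RUN_AT")
```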
@@ -90,12 +90,10 @@ Address comments **one at a time**: fix → commit → push → inline reply →
2. Commit and push the fix
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):

Use a **markdown commit link** so GitHub renders it as a clickable reference. Get the full SHA with `git rev-parse HEAD` after committing:

| Comment type | How to reply |
|---|---|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
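A sketch of wiring the table above together (the SHA here is a placeholder for illustration; in practice take it from `git rev-parse HEAD` right after committing, and fill in the PR/comment IDs):

```shell
# Build the inline-reply body for a fixing commit (illustrative values).
# In practice: FULL_SHA=$(git rev-parse HEAD) after committing the fix.
FULL_SHA="0123456789abcdef0123456789abcdef01234567"
SHORT_SHA=${FULL_SHA:0:7}
BODY="🤖 Fixed in [${SHORT_SHA}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): <description>"
echo "$BODY"
# Then post it as an inline reply (PR_NUMBER and COMMENT_ID are placeholders):
# gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR_NUMBER}/comments/${COMMENT_ID}/replies" -f body="$BODY"
```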

## Codecov coverage

@@ -530,9 +530,19 @@ After showing all screenshots, output a **detailed** summary table:
# but Homebrew bash is 5.x; Linux typically has bash 5.x). If running on Bash <4, use a
# plain variable with a lookup function instead.
declare -A SCREENSHOT_EXPLANATIONS=(
  ["01-login-page.png"]="Shows the login page loaded successfully with SSO options visible."
  ["02-builder-with-block.png"]="The builder canvas displays the newly added block connected to the trigger."
  # ... one entry per screenshot, using the same explanations you showed the user above
  # Each explanation MUST answer three things:
  # 1. FLOW: Which test scenario / user journey is this part of?
  # 2. STEPS: What exact actions were taken to reach this state?
  # 3. EVIDENCE: What does this screenshot prove (pass/fail/data)?
  #
  # Good example:
  # ["03-cost-log-after-run.png"]="Flow: LLM block cost tracking. Steps: Logged in as tester@gmail.com → ran 'Cost Test Agent' → waited for COMPLETED status. Evidence: PlatformCostLog table shows 1 new row with cost_microdollars=1234 and correct user_id."
  #
  # Bad example (too vague — never do this):
  # ["03-cost-log.png"]="Shows the cost log table."
  ["01-login-page.png"]="Flow: Login flow. Steps: Opened /login. Evidence: Login page renders with email/password fields and SSO options visible."
  ["02-builder-with-block.png"]="Flow: Block execution. Steps: Logged in → /build → added LLM block. Evidence: Builder canvas shows block connected to trigger, ready to run."
  # ... one entry per screenshot using the flow/steps/evidence format above
)
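For the Bash <4 fallback mentioned in the comment above, one possible shape (entries here are illustrative, not from a real run) is a case-based lookup function instead of `declare -A`:

```shell
# Bash 3 fallback: a lookup function instead of an associative array.
# Entries are illustrative — use your real screenshot names and explanations.
screenshot_explanation() {
  case "$1" in
    01-login-page.png)
      echo "Flow: Login flow. Steps: Opened /login. Evidence: Login page renders with SSO options visible." ;;
    02-builder-with-block.png)
      echo "Flow: Block execution. Steps: Logged in → /build → added LLM block. Evidence: Block connected to trigger." ;;
    *)
      echo "No explanation recorded for $1" ;;
  esac
}
```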

TEST_RESULTS_TABLE="| 1 | Login flow | PASS | N/A | 01-login-before.png, 02-login-after.png |
@@ -547,7 +557,8 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations

**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**

**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as an inline `![name](url)` image in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.

```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
@@ -584,11 +595,11 @@ for img in "${SCREENSHOT_FILES[@]}"; do
done
TREE_JSON+=']'

# Step 2: Create tree, commit, and branch ref
# Step 2: Create tree, commit (with parent), and branch ref
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')

# Resolve parent commit so screenshots are chained, not orphan root commits
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || echo "")
# Resolve existing branch tip as parent (avoids orphan commits on repeat runs)
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || true)
if [ -n "$PARENT_SHA" ]; then
  COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
    -f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
@@ -596,6 +607,7 @@ if [ -n "$PARENT_SHA" ]; then
    -f "parents[]=$PARENT_SHA" \
    --jq '.sha')
else
  # First commit on this branch — no parent
  COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
    -f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
    -f tree="$TREE_SHA" \
@@ -606,7 +618,7 @@ gh api "repos/${REPO}/git/refs" \
  -f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
  -f sha="$COMMIT_SHA" 2>/dev/null \
  || gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
    -X PATCH -f sha="$COMMIT_SHA" -F force=true
    -X PATCH -f sha="$COMMIT_SHA" -f force=true
```

Then post the comment with **inline images AND explanations for each screenshot**:
@@ -670,122 +682,122 @@ ${IMAGE_MARKDOWN}
${FAILED_SECTION}
INNEREOF

gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
POSTED_BODY=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE" --jq '.body')
rm -f "$COMMENT_FILE"

# Verify the posted comment contains inline images — exit 1 if none found
# Use separate --paginate + jq pipe: --jq applies per-page, not to the full list
LAST_COMMENT=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" --paginate 2>/dev/null | jq -r '.[-1].body // ""')
if ! echo "$LAST_COMMENT" | grep -q '!\['; then
  echo "ERROR: Posted comment contains no inline images (![). Bare directory links are not acceptable." >&2
  exit 1
fi
echo "✓ Inline images verified in posted comment"
```

**The PR comment MUST include:**
1. A summary table of all scenarios with PASS/FAIL and before/after API evidence
2. Every successfully uploaded screenshot rendered inline; any failed uploads listed with manual attachment instructions
3. A 1-2 sentence explanation below each screenshot describing what it proves
3. A structured explanation below each screenshot covering: **Flow** (which scenario), **Steps** (exact actions taken to reach this state), **Evidence** (what this proves — pass/fail/data values). A bare "shows the page" caption is not acceptable.

This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.

## Step 8: Evaluate and post a formal PR review

After the test comment is posted, evaluate whether the run was thorough enough to make a merge decision, then post a formal GitHub review (approve or request changes). **This step is mandatory — every test run MUST end with a formal review decision.**

### Evaluation criteria

Re-read the PR description:
```bash
gh pr view "$PR_NUMBER" --json body --jq '.body' --repo "$REPO"
```

Score the run against each criterion:

| Criterion | Pass condition |
|-----------|---------------|
| **Coverage** | Every feature/change described in the PR has at least one test scenario |
| **All scenarios pass** | No FAIL rows in the results table |
| **Negative tests** | At least one failure-path test per feature (invalid input, unauthorized, edge case) |
| **Before/after evidence** | Every state-changing API call has before/after values logged |
| **Screenshots are meaningful** | Screenshots show the actual state change, not just a loading spinner or blank page |
| **No regressions** | Existing core flows (login, agent create/run) still work |

### Decision logic

```
ALL criteria pass → APPROVE
Any scenario FAIL or missing PR feature → REQUEST_CHANGES (list gaps)
Evidence weak (no before/after, vague shots) → REQUEST_CHANGES (list what's missing)
```
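The decision rule above can be sketched in shell. This is a sketch only — `FAIL_COUNT` and `COVERAGE_GAPS` are assumed to have been computed by the evaluation code elsewhere in this skill; the values here are illustrative:

```shell
# Sketch of the verdict rule: approve only when nothing failed and no gaps remain.
FAIL_COUNT=0       # assumed: count of FAIL rows from the results table
COVERAGE_GAPS=()   # assumed: populated while assessing coverage
if [ "$FAIL_COUNT" -eq 0 ] && [ "${#COVERAGE_GAPS[@]}" -eq 0 ]; then
  VERDICT="APPROVE"
else
  VERDICT="REQUEST_CHANGES"
fi
echo "$VERDICT"   # → APPROVE
```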

### Post the review
**Verify inline rendering after posting — this is required, not optional:**

```bash
REVIEW_FILE=$(mktemp)
# 1. Confirm the posted comment body contains inline image markdown syntax
if ! echo "$POSTED_BODY" | grep -q '!\['; then
  echo "❌ FAIL: No inline image tags in posted comment body. Re-check IMAGE_MARKDOWN and re-post."
  exit 1
fi

# Count results
PASS_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "PASS" || true)
FAIL_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "FAIL" || true)
TOTAL=$(( PASS_COUNT + FAIL_COUNT ))

# List any coverage gaps found during evaluation (populate this array as you assess)
# e.g. COVERAGE_GAPS=("PR claims to add X but no test covers it")
COVERAGE_GAPS=()
# 2. Verify at least one raw URL actually resolves (catches wrong branch name, wrong path, etc.)
FIRST_IMG_URL=$(echo "$POSTED_BODY" | grep -o 'https://raw.githubusercontent.com[^)]*' | head -1)
if [ -n "$FIRST_IMG_URL" ]; then
  HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" --max-time 10 "$FIRST_IMG_URL")
  if [ "$HTTP_STATUS" = "200" ]; then
    echo "✅ Inline images confirmed and raw URL resolves (HTTP 200)"
  else
    echo "❌ FAIL: Raw image URL returned HTTP $HTTP_STATUS — images will not render inline."
    echo " URL: $FIRST_IMG_URL"
    echo " Check branch name, path, and that the push succeeded."
    exit 1
  fi
else
  echo "⚠️ Could not extract a raw URL from the comment — verify manually."
fi
```

**If APPROVING** — all criteria met, zero failures, full coverage:
## Step 8: Evaluate test completeness and post a GitHub review

After posting the PR comment, evaluate whether the test run actually covered everything it needed to. This is NOT a rubber-stamp — be critical. Then post a formal GitHub review so the PR author and reviewers can see the verdict.

### 8a. Evaluate against the test plan

Re-read `$RESULTS_DIR/test-plan.md` (written in Step 2) and `$RESULTS_DIR/test-report.md` (written in Step 5). For each scenario in the plan, answer:

> **Note:** `test-report.md` is written in Step 5. If it doesn't exist, write it before proceeding here — see the Step 5 template. Do not skip evaluation because the file is missing; create it from your notes instead.

| Question | Pass criteria |
|----------|--------------|
| Was it tested? | Explicit steps were executed, not just described |
| Is there screenshot evidence? | At least one before/after screenshot per scenario |
| Did the core feature work correctly? | Expected state matches actual state |
| Were negative cases tested? | At least one failure/rejection case per feature |
| Was DB/API state verified (not just UI)? | Raw API response or DB query confirms state change |

Build a verdict:
- **APPROVE** — every scenario tested, evidence present, no bugs found or all bugs are minor/known
- **REQUEST_CHANGES** — one or more: untested scenarios, missing evidence, bugs found, data not verified

### 8b. Post the GitHub review

```bash
cat > "$REVIEW_FILE" <<REVIEWEOF
## E2E Test Evaluation — APPROVED
EVAL_FILE=$(mktemp)

**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed.
# === STEP A: Write header ===
cat > "$EVAL_FILE" << 'ENDEVAL'
## 🧪 Test Evaluation

**Coverage:** All features described in the PR were exercised.
### Coverage checklist
ENDEVAL

**Evidence:** Before/after API values logged for all state-changing operations; screenshots show meaningful state transitions.
# === STEP B: Append ONE line per scenario — do this BEFORE calculating verdict ===
# Format: "- ✅ **Scenario N – name**: <what was done and verified>"
# or "- ❌ **Scenario N – name**: <what is missing or broken>"
# Examples:
# echo "- ✅ **Scenario 1 – Login flow**: tested, screenshot evidence present, auth token verified via API" >> "$EVAL_FILE"
# echo "- ❌ **Scenario 3 – Cost logging**: NOT verified in DB — UI showed entry but raw SQL query was skipped" >> "$EVAL_FILE"
#
# !!! IMPORTANT: append ALL scenario lines here before proceeding to STEP C !!!

**Negative tests:** Failure paths tested for each feature.
# === STEP C: Derive verdict from the checklist — runs AFTER all lines are appended ===
FAIL_COUNT=$(grep -c "^- ❌" "$EVAL_FILE" || true)
if [ "$FAIL_COUNT" -eq 0 ]; then
  VERDICT="APPROVE"
else
  VERDICT="REQUEST_CHANGES"
fi

No regressions observed on core flows.
REVIEWEOF
# === STEP D: Append verdict section ===
cat >> "$EVAL_FILE" << ENDVERDICT

gh pr review "$PR_NUMBER" --repo "$REPO" --approve --body "$(cat "$REVIEW_FILE")"
echo "✅ PR approved"
```
### Verdict
ENDVERDICT

**If REQUESTING CHANGES** — any failure, coverage gap, or missing evidence:
if [ "$VERDICT" = "APPROVE" ]; then
  echo "✅ All scenarios covered with evidence. No blocking issues found." >> "$EVAL_FILE"
else
  echo "❌ $FAIL_COUNT scenario(s) incomplete or have confirmed bugs. See ❌ items above." >> "$EVAL_FILE"
  echo "" >> "$EVAL_FILE"
  echo "**Required before merge:** address each ❌ item above." >> "$EVAL_FILE"
fi

```bash
FAIL_LIST=$(echo "$TEST_RESULTS_TABLE" | grep "FAIL" | awk -F'|' '{print "- Scenario" $2 "failed"}' || true)
# === STEP E: Post the review ===
gh api "repos/${REPO}/pulls/$PR_NUMBER/reviews" \
  --method POST \
  -f body="$(cat "$EVAL_FILE")" \
  -f event="$VERDICT"

cat > "$REVIEW_FILE" <<REVIEWEOF
## E2E Test Evaluation — Changes Requested

**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed, ${FAIL_COUNT} failed.

### Required before merge

${FAIL_LIST}
$(for gap in "${COVERAGE_GAPS[@]}"; do echo "- $gap"; done)

Please fix the above and re-run the E2E tests.
REVIEWEOF

gh pr review "$PR_NUMBER" --repo "$REPO" --request-changes --body "$(cat "$REVIEW_FILE")"
echo "❌ Changes requested"
```

```bash
rm -f "$REVIEW_FILE"
rm -f "$EVAL_FILE"
```

**Rules:**
- In `--fix` mode, fix all failures before posting the review — the review reflects the final state after fixes
- Never approve if any scenario failed, even if it seems like a flake — rerun that scenario first
- Never request changes for issues already fixed in this run
- Never auto-approve without checking every scenario in the test plan
- `REQUEST_CHANGES` if ANY scenario is untested, lacks DB/API evidence, or has a confirmed bug
- The evaluation body must list every scenario explicitly (✅ or ❌) — not just the failures
- If you find new bugs during evaluation, add them to the request-changes body and (if `--fix` flag is set) fix them before posting

## Fix mode (--fix flag)

@@ -0,0 +1,98 @@
import logging
from datetime import datetime

from autogpt_libs.auth import get_user_id, requires_admin_user
from cachetools import TTLCache
from fastapi import APIRouter, Query, Security
from pydantic import BaseModel

from backend.data.platform_cost import (
    CostLogRow,
    PlatformCostDashboard,
    get_platform_cost_dashboard,
    get_platform_cost_logs,
)
from backend.util.models import Pagination

logger = logging.getLogger(__name__)

# Cache dashboard results for 30 seconds per unique filter combination.
# The table is append-only so stale reads are acceptable for analytics.
_DASHBOARD_CACHE_TTL = 30
_dashboard_cache: TTLCache[tuple, PlatformCostDashboard] = TTLCache(
    maxsize=256, ttl=_DASHBOARD_CACHE_TTL
)


router = APIRouter(
    prefix="/platform-costs",
    tags=["platform-cost", "admin"],
    dependencies=[Security(requires_admin_user)],
)


class PlatformCostLogsResponse(BaseModel):
    logs: list[CostLogRow]
    pagination: Pagination


@router.get(
    "/dashboard",
    response_model=PlatformCostDashboard,
    summary="Get Platform Cost Dashboard",
)
async def get_cost_dashboard(
    admin_user_id: str = Security(get_user_id),
    start: datetime | None = Query(None),
    end: datetime | None = Query(None),
    provider: str | None = Query(None),
    user_id: str | None = Query(None),
):
    logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
    cache_key = (start, end, provider, user_id)
    cached = _dashboard_cache.get(cache_key)
    if cached is not None:
        return cached
    result = await get_platform_cost_dashboard(
        start=start,
        end=end,
        provider=provider,
        user_id=user_id,
    )
    _dashboard_cache[cache_key] = result
    return result


@router.get(
    "/logs",
    response_model=PlatformCostLogsResponse,
    summary="Get Platform Cost Logs",
)
async def get_cost_logs(
    admin_user_id: str = Security(get_user_id),
    start: datetime | None = Query(None),
    end: datetime | None = Query(None),
    provider: str | None = Query(None),
    user_id: str | None = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
):
    logger.info("Admin %s fetching platform cost logs", admin_user_id)
    logs, total = await get_platform_cost_logs(
        start=start,
        end=end,
        provider=provider,
        user_id=user_id,
        page=page,
        page_size=page_size,
    )
    total_pages = (total + page_size - 1) // page_size
    return PlatformCostLogsResponse(
        logs=logs,
        pagination=Pagination(
            total_items=total,
            total_pages=total_pages,
            current_page=page,
            page_size=page_size,
        ),
    )
@@ -0,0 +1,192 @@
from unittest.mock import AsyncMock

import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload

from backend.data.platform_cost import PlatformCostDashboard

from . import platform_cost_routes
from .platform_cost_routes import router as platform_cost_router

app = fastapi.FastAPI()
app.include_router(platform_cost_router)

client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
    """Setup admin auth overrides for all tests in this module"""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
    # Clear TTL cache so each test starts cold.
    platform_cost_routes._dashboard_cache.clear()
    yield
    app.dependency_overrides.clear()


def test_get_dashboard_success(
    mocker: pytest_mock.MockerFixture,
) -> None:
    real_dashboard = PlatformCostDashboard(
        by_provider=[],
        by_user=[],
        total_cost_microdollars=0,
        total_requests=0,
        total_users=0,
    )
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
        AsyncMock(return_value=real_dashboard),
    )

    response = client.get("/platform-costs/dashboard")
    assert response.status_code == 200
    data = response.json()
    assert "by_provider" in data
    assert "by_user" in data
    assert data["total_cost_microdollars"] == 0


def test_get_logs_success(
    mocker: pytest_mock.MockerFixture,
) -> None:
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
        AsyncMock(return_value=([], 0)),
    )

    response = client.get("/platform-costs/logs")
    assert response.status_code == 200
    data = response.json()
    assert data["logs"] == []
    assert data["pagination"]["total_items"] == 0


def test_get_dashboard_with_filters(
    mocker: pytest_mock.MockerFixture,
) -> None:
    real_dashboard = PlatformCostDashboard(
        by_provider=[],
        by_user=[],
        total_cost_microdollars=0,
        total_requests=0,
        total_users=0,
    )
    mock_dashboard = AsyncMock(return_value=real_dashboard)
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
        mock_dashboard,
    )

    response = client.get(
        "/platform-costs/dashboard",
        params={
            "start": "2026-01-01T00:00:00",
            "end": "2026-04-01T00:00:00",
            "provider": "openai",
            "user_id": "test-user-123",
        },
    )
    assert response.status_code == 200
    mock_dashboard.assert_called_once()
    call_kwargs = mock_dashboard.call_args.kwargs
    assert call_kwargs["provider"] == "openai"
    assert call_kwargs["user_id"] == "test-user-123"
    assert call_kwargs["start"] is not None
    assert call_kwargs["end"] is not None


def test_get_logs_with_pagination(
    mocker: pytest_mock.MockerFixture,
) -> None:
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
        AsyncMock(return_value=([], 0)),
    )

    response = client.get(
        "/platform-costs/logs",
        params={"page": 2, "page_size": 25, "provider": "anthropic"},
    )
    assert response.status_code == 200
    data = response.json()
    assert data["pagination"]["current_page"] == 2
    assert data["pagination"]["page_size"] == 25


def test_get_dashboard_requires_admin() -> None:
    from fastapi import HTTPException

    def reject_jwt(request: fastapi.Request):
        raise HTTPException(status_code=401, detail="Not authenticated")

    app.dependency_overrides[get_jwt_payload] = reject_jwt
    try:
        response = client.get("/platform-costs/dashboard")
        assert response.status_code == 401
        response = client.get("/platform-costs/logs")
        assert response.status_code == 401
    finally:
        app.dependency_overrides.clear()


def test_get_dashboard_rejects_non_admin(mock_jwt_user, mock_jwt_admin) -> None:
    """Non-admin JWT must be rejected with 403 by requires_admin_user."""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    try:
        response = client.get("/platform-costs/dashboard")
        assert response.status_code == 403
        response = client.get("/platform-costs/logs")
        assert response.status_code == 403
    finally:
        app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]


def test_get_logs_invalid_page_size_too_large() -> None:
    """page_size > 200 must be rejected with 422."""
    response = client.get("/platform-costs/logs", params={"page_size": 201})
    assert response.status_code == 422


def test_get_logs_invalid_page_size_zero() -> None:
    """page_size = 0 (below ge=1) must be rejected with 422."""
    response = client.get("/platform-costs/logs", params={"page_size": 0})
    assert response.status_code == 422


def test_get_logs_invalid_page_negative() -> None:
    """page < 1 must be rejected with 422."""
    response = client.get("/platform-costs/logs", params={"page": 0})
    assert response.status_code == 422


def test_get_dashboard_invalid_date_format() -> None:
    """Malformed start date must be rejected with 422."""
    response = client.get("/platform-costs/dashboard", params={"start": "not-a-date"})
    assert response.status_code == 422


def test_get_dashboard_cache_hit(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Second identical request returns cached result without calling the DB again."""
    real_dashboard = PlatformCostDashboard(
        by_provider=[],
        by_user=[],
        total_cost_microdollars=42,
        total_requests=1,
        total_users=1,
    )
    mock_fn = mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
        AsyncMock(return_value=real_dashboard),
    )

    client.get("/platform-costs/dashboard")
    client.get("/platform-costs/dashboard")

    mock_fn.assert_awaited_once()  # second request hit the cache
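The fixture clears `platform_cost_routes._dashboard_cache` so each test starts cold, and the cache-hit test asserts that a second identical request never reaches the data layer. A minimal sketch of the kind of TTL cache this behavior implies (the class name, TTL value, and key shape are assumptions for illustration, not the route's actual implementation):

```python
import time


class TTLCache:
    """Dict-backed cache whose entries expire after a fixed time-to-live."""

    def __init__(self, ttl_seconds: float):
        self._ttl = ttl_seconds
        self._store: dict = {}

    def get(self, key):
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if time.monotonic() >= expires_at:  # expired: drop and report a miss
            del self._store[key]
            return None
        return value

    def set(self, key, value):
        self._store[key] = (value, time.monotonic() + self._ttl)

    def clear(self):
        self._store.clear()


cache = TTLCache(ttl_seconds=60)

# First lookup misses and populates; a second identical lookup would hit
# the cache instead of calling the (stand-in) data layer again.
if (dashboard := cache.get(("dashboard", None))) is None:
    dashboard = {"total_cost_microdollars": 42}  # stand-in for the DB call
    cache.set(("dashboard", None), dashboard)
```

Clearing the cache in an autouse fixture, as the tests do, is what keeps the hit/miss behavior deterministic across test ordering.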
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
 
 import backend.api.features.admin.credit_admin_routes
 import backend.api.features.admin.execution_analytics_routes
+import backend.api.features.admin.platform_cost_routes
 import backend.api.features.admin.rate_limit_admin_routes
 import backend.api.features.admin.store_admin_routes
 import backend.api.features.builder
@@ -329,6 +330,11 @@ app.include_router(
     tags=["v2", "admin"],
     prefix="/api/copilot",
 )
+app.include_router(
+    backend.api.features.admin.platform_cost_routes.router,
+    tags=["v2", "admin"],
+    prefix="/api/admin",
+)
 app.include_router(
     backend.api.features.executions.review.routes.router,
     tags=["v2", "executions", "review"],
@@ -17,7 +17,7 @@ from backend.blocks.apollo.models import (
     PrimaryPhone,
     SearchOrganizationsRequest,
 )
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
 
 
 class SearchOrganizationsBlock(Block):
@@ -218,6 +218,11 @@ To find IDs, identify the values for organization_id when you call this endpoint
     ) -> BlockOutput:
         query = SearchOrganizationsRequest(**input_data.model_dump())
         organizations = await self.search_organizations(query, credentials)
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=float(len(organizations)), provider_cost_type="items"
+            )
+        )
         for organization in organizations:
             yield "organization", organization
         yield "organizations", organizations
@@ -21,7 +21,7 @@ from backend.blocks.apollo.models import (
     SearchPeopleRequest,
     SenorityLevels,
 )
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
 
 
 class SearchPeopleBlock(Block):
@@ -366,4 +366,9 @@ class SearchPeopleBlock(Block):
             *(enrich_or_fallback(person) for person in people)
         )
 
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=float(len(people)), provider_cost_type="items"
+            )
+        )
         yield "people", people
@@ -0,0 +1,712 @@
|
||||
"""Unit tests for merge_stats cost tracking in individual blocks.
|
||||
|
||||
Covers the exa code_context, exa contents, and apollo organization blocks
|
||||
to verify provider cost is correctly extracted and reported.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TEST_EXA_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="exa",
|
||||
api_key=SecretStr("mock-exa-api-key"),
|
||||
title="Mock Exa API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_EXA_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_EXA_CREDENTIALS.provider,
|
||||
"id": TEST_EXA_CREDENTIALS.id,
|
||||
"type": TEST_EXA_CREDENTIALS.type,
|
||||
"title": TEST_EXA_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaCodeContextBlock — cost_dollars is a string like "0.005"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaCodeContextBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_float_cost(self):
|
||||
"""float(cost_dollars) parsed from API string and passed to merge_stats."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-1",
|
||||
"query": "how to use hooks",
|
||||
"response": "Here are some examples...",
|
||||
"resultsCount": 3,
|
||||
"costDollars": "0.005",
|
||||
"searchTime": 1.2,
|
||||
"outputTokens": 100,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="how to use hooks",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
results = []
|
||||
async for output in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
results.append(output)
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(0.005)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cost_dollars_does_not_raise(self):
|
||||
"""When cost_dollars cannot be parsed as float, merge_stats is not called."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-2",
|
||||
"query": "query",
|
||||
"response": "response",
|
||||
"resultsCount": 0,
|
||||
"costDollars": "N/A",
|
||||
"searchTime": 0.5,
|
||||
"outputTokens": 0,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
merge_calls: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: merge_calls.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="query",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merge_calls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cost_is_tracked(self):
|
||||
"""A zero cost_dollars string '0.0' should still be recorded."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-3",
|
||||
"query": "query",
|
||||
"response": "...",
|
||||
"resultsCount": 1,
|
||||
"costDollars": "0.0",
|
||||
"searchTime": 0.1,
|
||||
"outputTokens": 10,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="query",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaContentsBlock — response.cost_dollars.total (CostDollars model)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaContentsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_cost_dollars_total(self):
|
||||
"""provider_cost equals response.cost_dollars.total when present."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
|
||||
block = ExaContentsBlock()
|
||||
|
||||
cost_dollars = CostDollars(total=0.012)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_response.context = None
|
||||
mock_response.statuses = None
|
||||
mock_response.cost_dollars = cost_dollars
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.contents.AsyncExa",
|
||||
return_value=MagicMock(
|
||||
get_contents=AsyncMock(return_value=mock_response)
|
||||
),
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaContentsBlock.Input(
|
||||
urls=["https://example.com"],
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(0.012)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_stats_when_cost_dollars_absent(self):
|
||||
"""When response.cost_dollars is None, merge_stats is not called."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
|
||||
block = ExaContentsBlock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_response.context = None
|
||||
mock_response.statuses = None
|
||||
mock_response.cost_dollars = None
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.contents.AsyncExa",
|
||||
return_value=MagicMock(
|
||||
get_contents=AsyncMock(return_value=mock_response)
|
||||
),
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaContentsBlock.Input(
|
||||
urls=["https://example.com"],
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert accumulated == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SearchOrganizationsBlock — provider_cost = float(len(organizations))
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchOrganizationsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_org_count(self):
|
||||
"""provider_cost == number of returned organizations, type == 'items'."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.models import Organization
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
|
||||
block = SearchOrganizationsBlock()
|
||||
|
||||
fake_orgs = [Organization(id=str(i), name=f"Org{i}") for i in range(3)]
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchOrganizationsBlock,
|
||||
"search_organizations",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_orgs,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchOrganizationsBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
results = []
|
||||
async for output in block.run(
|
||||
input_data,
|
||||
credentials=APOLLO_CREDS,
|
||||
):
|
||||
results.append(output)
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(3.0)
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_org_list_tracks_zero(self):
|
||||
"""An empty organization list results in provider_cost=0.0."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
|
||||
block = SearchOrganizationsBlock()
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchOrganizationsBlock,
|
||||
"search_organizations",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchOrganizationsBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=APOLLO_CREDS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JinaEmbeddingBlock — token count from usage.total_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJinaEmbeddingBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_token_count(self):
|
||||
"""provider token count is recorded when API returns usage.total_tokens."""
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
|
||||
block = JinaEmbeddingBlock()
|
||||
|
||||
api_response = {
|
||||
"data": [{"embedding": [0.1, 0.2, 0.3]}],
|
||||
"usage": {"total_tokens": 42},
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.jina.embeddings.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = JinaEmbeddingBlock.Input(
|
||||
texts=["hello world"],
|
||||
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=JINA_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].input_token_count == 42
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_stats_when_usage_absent(self):
|
||||
"""When API response omits usage field, merge_stats is not called."""
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
|
||||
block = JinaEmbeddingBlock()
|
||||
|
||||
api_response = {
|
||||
"data": [{"embedding": [0.1, 0.2, 0.3]}],
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.jina.embeddings.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = JinaEmbeddingBlock.Input(
|
||||
texts=["hello"],
|
||||
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=JINA_CREDS):
|
||||
pass
|
||||
|
||||
assert accumulated == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UnrealTextToSpeechBlock — character count from input text length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUnrealTextToSpeechBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_character_count(self):
|
||||
"""provider_cost equals len(text) with type='characters'."""
|
||||
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
|
||||
from backend.blocks.text_to_speech_block import (
|
||||
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
|
||||
block = UnrealTextToSpeechBlock()
|
||||
test_text = "Hello, world!"
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
UnrealTextToSpeechBlock,
|
||||
"call_unreal_speech_api",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"OutputUri": "https://example.com/audio.mp3"},
|
||||
),
|
||||
patch.object(block, "merge_stats") as mock_merge,
|
||||
):
|
||||
input_data = UnrealTextToSpeechBlock.Input(
|
||||
text=test_text,
|
||||
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=TTS_CREDS):
|
||||
pass
|
||||
|
||||
mock_merge.assert_called_once()
|
||||
stats = mock_merge.call_args[0][0]
|
||||
assert stats.provider_cost == float(len(test_text))
|
||||
assert stats.provider_cost_type == "characters"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_text_gives_zero_characters(self):
|
||||
"""An empty text string results in provider_cost=0.0."""
|
||||
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
|
||||
from backend.blocks.text_to_speech_block import (
|
||||
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
|
||||
block = UnrealTextToSpeechBlock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
UnrealTextToSpeechBlock,
|
||||
"call_unreal_speech_api",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"OutputUri": "https://example.com/audio.mp3"},
|
||||
),
|
||||
patch.object(block, "merge_stats") as mock_merge,
|
||||
):
|
||||
input_data = UnrealTextToSpeechBlock.Input(
|
||||
text="",
|
||||
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=TTS_CREDS):
|
||||
pass
|
||||
|
||||
mock_merge.assert_called_once()
|
||||
stats = mock_merge.call_args[0][0]
|
||||
assert stats.provider_cost == 0.0
|
||||
assert stats.provider_cost_type == "characters"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GoogleMapsSearchBlock — item count from search_places results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGoogleMapsSearchBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_place_count(self):
|
||||
"""provider_cost equals number of returned places, type == 'items'."""
|
||||
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
|
||||
from backend.blocks.google_maps import (
|
||||
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.google_maps import GoogleMapsSearchBlock
|
||||
|
||||
block = GoogleMapsSearchBlock()
|
||||
|
||||
fake_places = [{"name": f"Place{i}", "address": f"Addr{i}"} for i in range(4)]
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
GoogleMapsSearchBlock,
|
||||
"search_places",
|
||||
return_value=fake_places,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = GoogleMapsSearchBlock.Input(
|
||||
query="coffee shops",
|
||||
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=MAPS_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 4.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_results_tracks_zero(self):
|
||||
"""Zero places returned results in provider_cost=0.0."""
|
||||
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
|
||||
from backend.blocks.google_maps import (
|
||||
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.google_maps import GoogleMapsSearchBlock
|
||||
|
||||
block = GoogleMapsSearchBlock()
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
GoogleMapsSearchBlock,
|
||||
"search_places",
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = GoogleMapsSearchBlock.Input(
|
||||
query="nothing here",
|
||||
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=MAPS_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SmartLeadAddLeadsBlock — item count from lead_list length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSmartLeadAddLeadsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_lead_count(self):
|
||||
"""provider_cost equals number of leads uploaded, type == 'items'."""
|
||||
from backend.blocks.smartlead._auth import TEST_CREDENTIALS as SL_CREDS
|
||||
from backend.blocks.smartlead._auth import (
|
||||
TEST_CREDENTIALS_INPUT as SL_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.smartlead.campaign import AddLeadToCampaignBlock
|
||||
from backend.blocks.smartlead.models import (
|
||||
AddLeadsToCampaignResponse,
|
||||
LeadInput,
|
||||
)
|
||||
|
||||
block = AddLeadToCampaignBlock()
|
||||
|
||||
fake_leads = [
|
||||
LeadInput(first_name="Alice", last_name="A", email="alice@example.com"),
|
||||
LeadInput(first_name="Bob", last_name="B", email="bob@example.com"),
|
||||
]
|
||||
fake_response = AddLeadsToCampaignResponse(
|
||||
ok=True,
|
||||
upload_count=2,
|
||||
total_leads=2,
|
||||
block_count=0,
|
||||
duplicate_count=0,
|
||||
invalid_email_count=0,
|
||||
invalid_emails=[],
|
||||
already_added_to_campaign=0,
|
||||
unsubscribed_leads=[],
|
||||
is_lead_limit_exhausted=False,
|
||||
lead_import_stopped_count=0,
|
||||
bounce_count=0,
|
||||
)
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
AddLeadToCampaignBlock,
|
||||
"add_leads_to_campaign",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_response,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = AddLeadToCampaignBlock.Input(
|
||||
campaign_id=123,
|
||||
lead_list=fake_leads,
|
||||
credentials=SL_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=SL_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 2.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# SearchPeopleBlock — item count from people list length
# ---------------------------------------------------------------------------


class TestSearchPeopleBlockCostTracking:
    @pytest.mark.asyncio
    async def test_merge_stats_called_with_people_count(self):
        """provider_cost equals number of returned people, type == 'items'."""
        from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
        from backend.blocks.apollo._auth import (
            TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
        )
        from backend.blocks.apollo.models import Contact
        from backend.blocks.apollo.people import SearchPeopleBlock

        block = SearchPeopleBlock()
        fake_people = [Contact(id=str(i), first_name=f"Person{i}") for i in range(5)]
        accumulated: list[NodeExecutionStats] = []

        with (
            patch.object(
                SearchPeopleBlock,
                "search_people",
                new_callable=AsyncMock,
                return_value=fake_people,
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = SearchPeopleBlock.Input(
                credentials=APOLLO_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=APOLLO_CREDS):
                pass

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == pytest.approx(5.0)
        assert accumulated[0].provider_cost_type == "items"

    @pytest.mark.asyncio
    async def test_empty_people_list_tracks_zero(self):
        """An empty people list results in provider_cost=0.0."""
        from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
        from backend.blocks.apollo._auth import (
            TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
        )
        from backend.blocks.apollo.people import SearchPeopleBlock

        block = SearchPeopleBlock()
        accumulated: list[NodeExecutionStats] = []

        with (
            patch.object(
                SearchPeopleBlock,
                "search_people",
                new_callable=AsyncMock,
                return_value=[],
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = SearchPeopleBlock.Input(
                credentials=APOLLO_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=APOLLO_CREDS):
                pass

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == 0.0
        assert accumulated[0].provider_cost_type == "items"
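The two tests above pin down an item-count cost convention: `provider_cost` equals the number of returned records and `provider_cost_type` is `"items"`. A minimal standalone sketch of that convention (the `ItemCostStats` class and `items_cost` helper are hypothetical illustrations, not part of the codebase):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ItemCostStats:
    # Hypothetical stand-in for the cost fields of NodeExecutionStats.
    provider_cost: float
    provider_cost_type: str


def items_cost(results: list[Any]) -> ItemCostStats:
    # Cost is simply the number of returned items, typed as "items".
    return ItemCostStats(
        provider_cost=float(len(results)), provider_cost_type="items"
    )
```

Under this convention an empty result list is still a tracked event with cost 0.0, which is exactly what the second test asserts.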
@@ -9,6 +9,7 @@ from typing import Union

from pydantic import BaseModel

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -116,3 +117,10 @@ class ExaCodeContextBlock(Block):
        yield "cost_dollars", context.cost_dollars
        yield "search_time", context.search_time
        yield "output_tokens", context.output_tokens

        # Parse cost_dollars (API returns as string, e.g. "0.005")
        try:
            cost_usd = float(context.cost_dollars)
            self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
        except (ValueError, TypeError):
            pass
@@ -4,6 +4,7 @@ from typing import Optional
from exa_py import AsyncExa
from pydantic import BaseModel

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -223,3 +224,6 @@ class ExaContentsBlock(Block):

        if response.cost_dollars:
            yield "cost_dollars", response.cost_dollars
            self.merge_stats(
                NodeExecutionStats(provider_cost=response.cost_dollars.total)
            )
@@ -0,0 +1,575 @@
"""Tests for cost tracking in Exa blocks.

Covers the cost_dollars → provider_cost → merge_stats path for both
ExaContentsBlock and ExaCodeContextBlock.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.data.model import NodeExecutionStats


class TestExaCodeContextCostTracking:
    """ExaCodeContextBlock parses cost_dollars (string) and calls merge_stats."""

    @pytest.mark.asyncio
    async def test_valid_cost_string_is_parsed_and_merged(self):
        """A numeric cost string like '0.005' is merged as provider_cost."""
        from backend.blocks.exa.code_context import ExaCodeContextBlock

        block = ExaCodeContextBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        api_response = {
            "requestId": "req-1",
            "query": "test query",
            "response": "some code",
            "resultsCount": 3,
            "costDollars": "0.005",
            "searchTime": 1.2,
            "outputTokens": 100,
        }

        with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
            mock_resp = MagicMock()
            mock_resp.json.return_value = api_response
            mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)

            outputs = []
            async for key, value in block.run(
                block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                outputs.append((key, value))

        assert any(k == "cost_dollars" for k, _ in outputs)
        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.005)

    @pytest.mark.asyncio
    async def test_invalid_cost_string_does_not_raise(self):
        """A non-numeric cost_dollars value is swallowed silently."""
        from backend.blocks.exa.code_context import ExaCodeContextBlock

        block = ExaCodeContextBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        api_response = {
            "requestId": "req-2",
            "query": "test",
            "response": "code",
            "resultsCount": 0,
            "costDollars": "N/A",
            "searchTime": 0.5,
            "outputTokens": 0,
        }

        with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
            mock_resp = MagicMock()
            mock_resp.json.return_value = api_response
            mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)

            outputs = []
            async for key, value in block.run(
                block.Input(query="test", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                outputs.append((key, value))

        # No merge_stats call because float() raised ValueError
        assert len(merged) == 0

    @pytest.mark.asyncio
    async def test_zero_cost_string_is_merged(self):
        """'0.0' is a valid cost — should still be tracked."""
        from backend.blocks.exa.code_context import ExaCodeContextBlock

        block = ExaCodeContextBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        api_response = {
            "requestId": "req-3",
            "query": "free query",
            "response": "result",
            "resultsCount": 1,
            "costDollars": "0.0",
            "searchTime": 0.1,
            "outputTokens": 10,
        }

        with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
            mock_resp = MagicMock()
            mock_resp.json.return_value = api_response
            mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)

            async for _ in block.run(
                block.Input(query="free query", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.0)

class TestExaContentsCostTracking:
    """ExaContentsBlock merges cost_dollars.total as provider_cost."""

    @pytest.mark.asyncio
    async def test_cost_dollars_total_is_merged(self):
        """When the SDK response includes cost_dollars, its total is merged."""
        from backend.blocks.exa.contents import ExaContentsBlock
        from backend.blocks.exa.helpers import CostDollars

        block = ExaContentsBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.statuses = None
        mock_sdk_response.cost_dollars = CostDollars(total=0.012)

        with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.012)

    @pytest.mark.asyncio
    async def test_no_cost_dollars_skips_merge(self):
        """When cost_dollars is absent, merge_stats is not called."""
        from backend.blocks.exa.contents import ExaContentsBlock

        block = ExaContentsBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.statuses = None
        mock_sdk_response.cost_dollars = None

        with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 0

    @pytest.mark.asyncio
    async def test_zero_cost_dollars_is_merged(self):
        """A total of 0.0 (free tier) should still be merged."""
        from backend.blocks.exa.contents import ExaContentsBlock
        from backend.blocks.exa.helpers import CostDollars

        block = ExaContentsBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.statuses = None
        mock_sdk_response.cost_dollars = CostDollars(total=0.0)

        with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.0)

class TestExaSearchCostTracking:
    """ExaSearchBlock merges cost_dollars.total as provider_cost."""

    @pytest.mark.asyncio
    async def test_cost_dollars_total_is_merged(self):
        """When the SDK response includes cost_dollars, its total is merged."""
        from backend.blocks.exa.helpers import CostDollars
        from backend.blocks.exa.search import ExaSearchBlock

        block = ExaSearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.resolved_search_type = None
        mock_sdk_response.cost_dollars = CostDollars(total=0.008)

        with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.search = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.008)

    @pytest.mark.asyncio
    async def test_no_cost_dollars_skips_merge(self):
        """When cost_dollars is absent, merge_stats is not called."""
        from backend.blocks.exa.search import ExaSearchBlock

        block = ExaSearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.resolved_search_type = None
        mock_sdk_response.cost_dollars = None

        with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.search = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 0

class TestExaSimilarCostTracking:
    """ExaFindSimilarBlock merges cost_dollars.total as provider_cost."""

    @pytest.mark.asyncio
    async def test_cost_dollars_total_is_merged(self):
        """When the SDK response includes cost_dollars, its total is merged."""
        from backend.blocks.exa.helpers import CostDollars
        from backend.blocks.exa.similar import ExaFindSimilarBlock

        block = ExaFindSimilarBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.request_id = "req-1"
        mock_sdk_response.cost_dollars = CostDollars(total=0.015)

        with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.015)

    @pytest.mark.asyncio
    async def test_no_cost_dollars_skips_merge(self):
        """When cost_dollars is absent, merge_stats is not called."""
        from backend.blocks.exa.similar import ExaFindSimilarBlock

        block = ExaFindSimilarBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.request_id = "req-2"
        mock_sdk_response.cost_dollars = None

        with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 0

# ---------------------------------------------------------------------------
# ExaCreateResearchBlock — cost_dollars from completed poll response
# ---------------------------------------------------------------------------


COMPLETED_RESEARCH_RESPONSE = {
    "researchId": "test-research-id",
    "status": "completed",
    "model": "exa-research",
    "instructions": "test instructions",
    "createdAt": 1700000000000,
    "finishedAt": 1700000060000,
    "costDollars": {
        "total": 0.05,
        "numSearches": 3,
        "numPages": 10,
        "reasoningTokens": 500,
    },
    "output": {"content": "Research findings...", "parsed": None},
}

PENDING_RESEARCH_RESPONSE = {
    "researchId": "test-research-id",
    "status": "pending",
    "model": "exa-research",
    "instructions": "test instructions",
    "createdAt": 1700000000000,
}

class TestExaCreateResearchBlockCostTracking:
    """ExaCreateResearchBlock merges cost from completed poll response."""

    @pytest.mark.asyncio
    async def test_cost_merged_when_research_completes(self):
        """merge_stats called with provider_cost=total when poll returns completed."""
        from backend.blocks.exa.research import ExaCreateResearchBlock

        block = ExaCreateResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        create_resp = MagicMock()
        create_resp.json.return_value = PENDING_RESEARCH_RESPONSE

        poll_resp = MagicMock()
        poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE

        mock_instance = MagicMock()
        mock_instance.post = AsyncMock(return_value=create_resp)
        mock_instance.get = AsyncMock(return_value=poll_resp)

        with (
            patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
            patch("asyncio.sleep", new=AsyncMock()),
        ):
            async for _ in block.run(
                block.Input(
                    instructions="test instructions",
                    wait_for_completion=True,
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.05)

    @pytest.mark.asyncio
    async def test_no_merge_when_no_cost_dollars(self):
        """When completed response has no costDollars, merge_stats is not called."""
        from backend.blocks.exa.research import ExaCreateResearchBlock

        block = ExaCreateResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
        create_resp = MagicMock()
        create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
        poll_resp = MagicMock()
        poll_resp.json.return_value = no_cost_response

        mock_instance = MagicMock()
        mock_instance.post = AsyncMock(return_value=create_resp)
        mock_instance.get = AsyncMock(return_value=poll_resp)

        with (
            patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
            patch("asyncio.sleep", new=AsyncMock()),
        ):
            async for _ in block.run(
                block.Input(
                    instructions="test instructions",
                    wait_for_completion=True,
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert merged == []

# ---------------------------------------------------------------------------
# ExaGetResearchBlock — cost_dollars from single GET response
# ---------------------------------------------------------------------------


class TestExaGetResearchBlockCostTracking:
    """ExaGetResearchBlock merges cost when the fetched research has cost_dollars."""

    @pytest.mark.asyncio
    async def test_cost_merged_from_completed_research(self):
        """merge_stats called with provider_cost=total when research has costDollars."""
        from backend.blocks.exa.research import ExaGetResearchBlock

        block = ExaGetResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        get_resp = MagicMock()
        get_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE

        mock_instance = MagicMock()
        mock_instance.get = AsyncMock(return_value=get_resp)

        with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
            async for _ in block.run(
                block.Input(
                    research_id="test-research-id",
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.05)

    @pytest.mark.asyncio
    async def test_no_merge_when_no_cost_dollars(self):
        """When research has no costDollars, merge_stats is not called."""
        from backend.blocks.exa.research import ExaGetResearchBlock

        block = ExaGetResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
        get_resp = MagicMock()
        get_resp.json.return_value = no_cost_response

        mock_instance = MagicMock()
        mock_instance.get = AsyncMock(return_value=get_resp)

        with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
            async for _ in block.run(
                block.Input(
                    research_id="test-research-id",
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert merged == []

# ---------------------------------------------------------------------------
# ExaWaitForResearchBlock — cost_dollars from polling response
# ---------------------------------------------------------------------------


class TestExaWaitForResearchBlockCostTracking:
    """ExaWaitForResearchBlock merges cost when the polled research has cost_dollars."""

    @pytest.mark.asyncio
    async def test_cost_merged_when_research_completes(self):
        """merge_stats called with provider_cost=total once polling returns completed."""
        from backend.blocks.exa.research import ExaWaitForResearchBlock

        block = ExaWaitForResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        poll_resp = MagicMock()
        poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE

        mock_instance = MagicMock()
        mock_instance.get = AsyncMock(return_value=poll_resp)

        with (
            patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
            patch("asyncio.sleep", new=AsyncMock()),
        ):
            async for _ in block.run(
                block.Input(
                    research_id="test-research-id",
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.05)

    @pytest.mark.asyncio
    async def test_no_merge_when_no_cost_dollars(self):
        """When completed research has no costDollars, merge_stats is not called."""
        from backend.blocks.exa.research import ExaWaitForResearchBlock

        block = ExaWaitForResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
        poll_resp = MagicMock()
        poll_resp.json.return_value = no_cost_response

        mock_instance = MagicMock()
        mock_instance.get = AsyncMock(return_value=poll_resp)

        with (
            patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
            patch("asyncio.sleep", new=AsyncMock()),
        ):
            async for _ in block.run(
                block.Input(
                    research_id="test-research-id",
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert merged == []
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -232,6 +233,11 @@ class ExaCreateResearchBlock(Block):

                if research.cost_dollars:
                    yield "cost_total", research.cost_dollars.total
                    self.merge_stats(
                        NodeExecutionStats(
                            provider_cost=research.cost_dollars.total
                        )
                    )
                return

            await asyncio.sleep(check_interval)
@@ -346,6 +352,9 @@ class ExaGetResearchBlock(Block):
            yield "cost_searches", research.cost_dollars.num_searches
            yield "cost_pages", research.cost_dollars.num_pages
            yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
            self.merge_stats(
                NodeExecutionStats(provider_cost=research.cost_dollars.total)
            )

        yield "error_message", research.error

@@ -432,6 +441,9 @@ class ExaWaitForResearchBlock(Block):

            if research.cost_dollars:
                yield "cost_total", research.cost_dollars.total
                self.merge_stats(
                    NodeExecutionStats(provider_cost=research.cost_dollars.total)
                )

            return

@@ -4,6 +4,7 @@ from typing import Optional

from exa_py import AsyncExa

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -206,3 +207,6 @@ class ExaSearchBlock(Block):

        if response.cost_dollars:
            yield "cost_dollars", response.cost_dollars
            self.merge_stats(
                NodeExecutionStats(provider_cost=response.cost_dollars.total)
            )
@@ -3,6 +3,7 @@ from typing import Optional

from exa_py import AsyncExa

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -167,3 +168,6 @@ class ExaFindSimilarBlock(Block):

        if response.cost_dollars:
            yield "cost_dollars", response.cost_dollars
            self.merge_stats(
                NodeExecutionStats(provider_cost=response.cost_dollars.total)
            )
@@ -14,6 +14,7 @@ from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -117,6 +118,11 @@ class GoogleMapsSearchBlock(Block):
            input_data.radius,
            input_data.max_results,
        )
        self.merge_stats(
            NodeExecutionStats(
                provider_cost=float(len(places)), provider_cost_type="items"
            )
        )
        for place in places:
            yield "place", place

@@ -10,7 +10,7 @@ from backend.blocks.jina._auth import (
    JinaCredentialsField,
    JinaCredentialsInput,
)
from backend.data.model import SchemaField
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.request import Requests


@@ -45,5 +45,13 @@ class JinaEmbeddingBlock(Block):
        }
        data = {"input": input_data.texts, "model": input_data.model}
        response = await Requests().post(url, headers=headers, json=data)
        embeddings = [e["embedding"] for e in response.json()["data"]]
        resp_json = response.json()
        embeddings = [e["embedding"] for e in resp_json["data"]]
        usage = resp_json.get("usage", {})
        if usage.get("total_tokens"):
            self.merge_stats(
                NodeExecutionStats(
                    input_token_count=usage.get("total_tokens", 0),
                )
            )
        yield "embeddings", embeddings
@@ -13,6 +13,7 @@ import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
from pydantic import BaseModel, SecretStr

from backend.blocks._base import (
@@ -737,6 +738,7 @@ class LLMResponse(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    reasoning: Optional[str] = None
    provider_cost: float | None = None


def convert_openai_tool_fmt_to_anthropic(
@@ -771,6 +773,32 @@ def convert_openai_tool_fmt_to_anthropic(
    return anthropic_tools


def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None:
    """Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response.

    OpenRouter returns the per-request USD cost in a response header. The
    OpenAI SDK exposes the raw httpx response via an undocumented `_response`
    attribute. We use try/except AttributeError so that if the SDK ever drops
    or renames that attribute, the warning is visible in logs rather than
    silently degrading to no cost tracking.
    """
    try:
        raw_resp = response._response  # type: ignore[attr-defined]
    except AttributeError:
        logger.warning(
            "OpenAI SDK response missing _response attribute"
            " — OpenRouter cost tracking unavailable"
        )
        return None
    try:
        cost_header = raw_resp.headers.get("x-total-cost")
        if not cost_header:
            return None
        return float(cost_header)
    except (ValueError, TypeError, AttributeError):
        return None

|
||||
"""Extract reasoning from OpenAI-compatible response if available."""
|
||||
"""Note: This will likely not working since the reasoning is not present in another Response API"""
|
||||
@@ -1103,6 +1131,7 @@ async def llm_call(
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
provider_cost=extract_openrouter_cost(response),
|
||||
)
|
||||
elif provider == "llama_api":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
@@ -1410,6 +1439,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
error_feedback_message = ""
|
||||
llm_model = input_data.model
|
||||
last_attempt_cost: float | None = None
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
@@ -1427,12 +1457,15 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
response_text = llm_response.response
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=llm_response.prompt_tokens,
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
# Merge token counts for every attempt (each call costs tokens).
|
||||
# provider_cost (actual USD) is tracked separately and only merged
|
||||
# on success to avoid double-counting across retries.
|
||||
token_stats = NodeExecutionStats(
|
||||
input_token_count=llm_response.prompt_tokens,
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
self.merge_stats(token_stats)
|
||||
last_attempt_cost = llm_response.provider_cost
|
||||
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
@@ -1501,6 +1534,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
llm_retry_count=retry_count,
|
||||
provider_cost=last_attempt_cost,
|
||||
)
|
||||
)
|
||||
yield "response", response_obj
|
||||
@@ -1521,6 +1555,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
llm_retry_count=retry_count,
|
||||
provider_cost=last_attempt_cost,
|
||||
)
|
||||
)
|
||||
yield "response", {"response": response_text}
|
||||
|
||||
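The retry accounting in the hunks above follows one rule: token counts accumulate on every attempt (each call consumes tokens), while the USD `provider_cost` keeps only the final successful attempt's value. A minimal, self-contained simulation of that policy (the `account_retries` function is a hypothetical illustration, not the block's actual code):

```python
def account_retries(attempts: list[tuple[int, int, float, bool]]) -> dict:
    # Each attempt is (prompt_tokens, completion_tokens, cost_usd, succeeded).
    # Tokens accumulate across all attempts; provider_cost records only the
    # attempt that finally succeeded, avoiding double-counting of USD cost.
    totals: dict = {"input_tokens": 0, "output_tokens": 0, "provider_cost": None}
    for prompt_tokens, completion_tokens, cost, succeeded in attempts:
        totals["input_tokens"] += prompt_tokens
        totals["output_tokens"] += completion_tokens
        if succeeded:
            totals["provider_cost"] = cost
            break
    return totals
```

With one failed attempt (10/5 tokens, $0.01) followed by a success (20/10 tokens, $0.02), tokens sum to 30/15 while the tracked cost is only the successful attempt's $0.02, matching the `last_attempt_cost` behavior tested further below.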
@@ -23,7 +23,7 @@ from backend.blocks.smartlead.models import (
    SaveSequencesResponse,
    Sequence,
)
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


class CreateCampaignBlock(Block):
@@ -226,6 +226,12 @@ class AddLeadToCampaignBlock(Block):
        response = await self.add_leads_to_campaign(
            input_data.campaign_id, input_data.lead_list, credentials
        )
        self.merge_stats(
            NodeExecutionStats(
                provider_cost=float(len(input_data.lead_list)),
                provider_cost_type="items",
            )
        )

        yield "campaign_id", input_data.campaign_id
        yield "upload_count", response.upload_count
@@ -199,6 +199,66 @@ class TestLLMStatsTracking:
        assert block.execution_stats.llm_call_count == 2  # retry_count + 1 = 1 + 1 = 2
        assert block.execution_stats.llm_retry_count == 1

    @pytest.mark.asyncio
    async def test_retry_cost_uses_last_attempt_only(self):
        """provider_cost is only merged from the final successful attempt.

        Intermediate retry costs are intentionally dropped to avoid
        double-counting: the cost of failed attempts is captured in
        last_attempt_cost only when the loop eventually succeeds.
        """
        import backend.blocks.llm as llm

        block = llm.AIStructuredResponseGeneratorBlock()
        call_count = 0

        async def mock_llm_call(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # First attempt: fails validation, returns cost $0.01
                return llm.LLMResponse(
                    raw_response="",
                    prompt=[],
                    response='<json_output id="test123456">{"wrong": "key"}</json_output>',
                    tool_calls=None,
                    prompt_tokens=10,
                    completion_tokens=5,
                    reasoning=None,
                    provider_cost=0.01,
                )
            # Second attempt: succeeds, returns cost $0.02
            return llm.LLMResponse(
                raw_response="",
                prompt=[],
                response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
                tool_calls=None,
                prompt_tokens=20,
                completion_tokens=10,
                reasoning=None,
                provider_cost=0.02,
            )

        block.llm_call = mock_llm_call  # type: ignore

        input_data = llm.AIStructuredResponseGeneratorBlock.Input(
            prompt="Test prompt",
            expected_format={"key1": "desc1", "key2": "desc2"},
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
            retry=2,
        )

        with patch("secrets.token_hex", return_value="test123456"):
            async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                pass

        # Only the final successful attempt's cost is merged
        assert block.execution_stats.provider_cost == pytest.approx(0.02)
        # Tokens from both attempts accumulate
        assert block.execution_stats.input_token_count == 30
        assert block.execution_stats.output_token_count == 15

    @pytest.mark.asyncio
    async def test_ai_text_summarizer_multiple_chunks(self):
        """Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""
@@ -987,3 +1047,51 @@ class TestLlmModelMissing:
        assert (
            llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
        )


class TestExtractOpenRouterCost:
    """Tests for extract_openrouter_cost — the x-total-cost header parser."""

    def _mk_response(self, headers: dict | None):
        response = MagicMock()
        if headers is None:
            response._response = None
        else:
            raw = MagicMock()
            raw.headers = headers
            response._response = raw
        return response

    def test_extracts_numeric_cost(self):
        response = self._mk_response({"x-total-cost": "0.0042"})
        assert llm.extract_openrouter_cost(response) == 0.0042

    def test_returns_none_when_header_missing(self):
        response = self._mk_response({})
        assert llm.extract_openrouter_cost(response) is None

    def test_returns_none_when_header_empty_string(self):
        response = self._mk_response({"x-total-cost": ""})
        assert llm.extract_openrouter_cost(response) is None

    def test_returns_none_when_header_non_numeric(self):
        response = self._mk_response({"x-total-cost": "not-a-number"})
        assert llm.extract_openrouter_cost(response) is None

    def test_returns_none_when_no_response_attr(self):
        response = MagicMock(spec=[])  # no _response attr
        assert llm.extract_openrouter_cost(response) is None

    def test_returns_none_when_raw_is_none(self):
        response = self._mk_response(None)
        assert llm.extract_openrouter_cost(response) is None

    def test_returns_none_when_raw_has_no_headers(self):
        response = MagicMock()
        response._response = MagicMock(spec=[])  # no headers attr
        assert llm.extract_openrouter_cost(response) is None

    def test_returns_zero_for_zero_cost(self):
        """Zero-cost is a valid value (free tier) and must not become None."""
        response = self._mk_response({"x-total-cost": "0"})
        assert llm.extract_openrouter_cost(response) == 0.0

@@ -13,6 +13,7 @@ from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -104,4 +105,10 @@ class UnrealTextToSpeechBlock(Block):
            input_data.text,
            input_data.voice_id,
        )
        self.merge_stats(
            NodeExecutionStats(
                provider_cost=float(len(input_data.text)),
                provider_cost_type="characters",
            )
        )
        yield "mp3_url", api_response["OutputUri"]

@@ -9,6 +9,7 @@ shared tool registry as the SDK path.
import asyncio
import base64
import logging
import math
import os
import re
import shutil
@@ -22,6 +23,7 @@ from typing import TYPE_CHECKING, Any, cast
import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from opentelemetry import trace as otel_trace

from backend.copilot.config import CopilotMode
from backend.copilot.context import get_workspace_manager, set_execution_context
@@ -334,6 +336,7 @@ class _BaselineStreamState:
    text_started: bool = False
    turn_prompt_tokens: int = 0
    turn_completion_tokens: int = 0
    cost_usd: float | None = None
    thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper)
    session_messages: list[ChatMessage] = field(default_factory=list)

@@ -354,6 +357,7 @@ async def _baseline_llm_caller(
    state.thinking_stripper = _ThinkingStripper()

    round_text = ""
    response = None  # initialized before try so finally block can access it
    try:
        client = _get_openai_client()
        typed_messages = cast(list[ChatCompletionMessageParam], messages)
@@ -430,6 +434,18 @@ async def _baseline_llm_caller(
                state.text_started = False
                state.text_block_id = str(uuid.uuid4())
    finally:
        # Extract OpenRouter cost from response headers (in finally so we
        # capture cost even when the stream errors mid-way — we already paid).
        # Accumulate across multi-round tool-calling turns.
        try:
            cost_header = response._response.headers.get("x-total-cost")  # type: ignore[attr-defined]
            if cost_header:
                cost = float(cost_header)
                if math.isfinite(cost):
                    state.cost_usd = (state.cost_usd or 0.0) + max(0.0, cost)
        except (AttributeError, ValueError):
            pass

        # Always persist partial text so the session history stays consistent,
        # even when the stream is interrupted by an exception.
        state.assistant_text += round_text
@@ -1183,8 +1199,22 @@ async def stream_chat_completion_baseline(
        yield StreamError(errorText=error_msg, code="baseline_error")
        # Still persist whatever we got
    finally:
        # Close Langfuse trace context
        # Set cost attributes on OTEL span before closing
        if _trace_ctx is not None:
            try:
                span = otel_trace.get_current_span()
                if span and span.is_recording():
                    span.set_attribute(
                        "gen_ai.usage.prompt_tokens", state.turn_prompt_tokens
                    )
                    span.set_attribute(
                        "gen_ai.usage.completion_tokens",
                        state.turn_completion_tokens,
                    )
                    if state.cost_usd is not None:
                        span.set_attribute("gen_ai.usage.cost_usd", state.cost_usd)
            except Exception:
                logger.debug("[Baseline] Failed to set OTEL cost attributes")
            try:
                _trace_ctx.__exit__(None, None, None)
            except Exception:
@@ -1226,6 +1256,8 @@ async def stream_chat_completion_baseline(
        prompt_tokens=state.turn_prompt_tokens,
        completion_tokens=state.turn_completion_tokens,
        log_prefix="[Baseline]",
        cost_usd=state.cost_usd,
        model=active_model,
    )

    # Persist structured tool-call history (assistant + tool messages)

@@ -4,7 +4,7 @@ These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState`
without requiring API keys, database connections, or network access.
"""

from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from openai.types.chat import ChatCompletionToolParam
@@ -631,3 +631,169 @@ class TestPrepareBaselineAttachments:

        assert hint == ""
        assert blocks == []


class TestBaselineCostExtraction:
    """Tests for x-total-cost header extraction in _baseline_llm_caller."""

    @pytest.mark.asyncio
    async def test_cost_usd_extracted_from_response_header(self):
        """state.cost_usd is set from x-total-cost header when present."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        # Build a mock raw httpx response with the cost header
        mock_raw_response = MagicMock()
        mock_raw_response.headers = {"x-total-cost": "0.0123"}

        # Build a mock async streaming response that yields no chunks but has
        # a _response attribute pointing to the mock httpx response
        mock_stream_response = MagicMock()
        mock_stream_response._response = mock_raw_response

        async def empty_aiter():
            return
            yield  # make it an async generator

        mock_stream_response.__aiter__ = lambda self: empty_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(
            return_value=mock_stream_response
        )

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd == pytest.approx(0.0123)

    @pytest.mark.asyncio
    async def test_cost_usd_accumulates_across_calls(self):
        """cost_usd accumulates when _baseline_llm_caller is called multiple times."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        def make_stream_mock(cost: str) -> MagicMock:
            mock_raw = MagicMock()
            mock_raw.headers = {"x-total-cost": cost}
            mock_stream = MagicMock()
            mock_stream._response = mock_raw

            async def empty_aiter():
                return
                yield

            mock_stream.__aiter__ = lambda self: empty_aiter()
            return mock_stream

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(
            side_effect=[make_stream_mock("0.01"), make_stream_mock("0.02")]
        )

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "first"}],
                tools=[],
                state=state,
            )
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "second"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd == pytest.approx(0.03)

    @pytest.mark.asyncio
    async def test_no_cost_when_header_absent(self):
        """state.cost_usd remains None when response has no x-total-cost header."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        mock_raw = MagicMock()
        mock_raw.headers = {}
        mock_stream = MagicMock()
        mock_stream._response = mock_raw

        async def empty_aiter():
            return
            yield

        mock_stream.__aiter__ = lambda self: empty_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd is None

    @pytest.mark.asyncio
    async def test_cost_extracted_even_when_stream_raises(self):
        """cost_usd is captured in the finally block even when streaming fails."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        mock_raw = MagicMock()
        mock_raw.headers = {"x-total-cost": "0.005"}
        mock_stream = MagicMock()
        mock_stream._response = mock_raw

        async def failing_aiter():
            raise RuntimeError("stream error")
            yield  # make it an async generator

        mock_stream.__aiter__ = lambda self: failing_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)

        with (
            patch(
                "backend.copilot.baseline.service._get_openai_client",
                return_value=mock_client,
            ),
            pytest.raises(RuntimeError, match="stream error"),
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd == pytest.approx(0.005)

@@ -151,8 +151,8 @@ class CoPilotProcessor:
        This method is called once per worker thread to set up the async event
        loop and initialize any required resources.

        Database is accessed only through DatabaseManager, so we don't need to connect
        to Prisma directly.
        DB operations route through DatabaseManagerAsyncClient (RPC) via the
        db_accessors pattern — no direct Prisma connection is needed here.
        """
        configure_logging()
        set_service_name("CoPilotExecutor")

@@ -15,6 +15,7 @@ from prisma.models import User as PrismaUser
from pydantic import BaseModel, Field
from redis.exceptions import RedisError

from backend.data.db_accessors import user_db
from backend.data.redis_client import get_redis_async
from backend.util.cache import cached

@@ -409,9 +410,12 @@ async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
    prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
    cached and then persists after the user is created with a higher tier.
    """
    user = await PrismaUser.prisma().find_unique(where={"id": user_id})
    if user and user.subscriptionTier:  # type: ignore[reportAttributeAccessIssue]
        return SubscriptionTier(user.subscriptionTier)  # type: ignore[reportAttributeAccessIssue]
    try:
        user = await user_db().get_user_by_id(user_id)
    except Exception:
        raise _UserNotFoundError(user_id)
    if user.subscription_tier:
        return SubscriptionTier(user.subscription_tier)
    raise _UserNotFoundError(user_id)


@@ -29,6 +29,7 @@ from claude_agent_sdk import (
)
from langfuse import propagate_attributes
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
from opentelemetry import trace as otel_trace
from pydantic import BaseModel

from backend.copilot.context import get_workspace_manager
@@ -2372,8 +2373,26 @@ async def stream_chat_completion_sdk(

        raise
    finally:
        # --- Close OTEL context ---
        # --- Close OTEL context (with cost attributes) ---
        if _otel_ctx is not None:
            try:
                span = otel_trace.get_current_span()
                if span and span.is_recording():
                    span.set_attribute("gen_ai.usage.prompt_tokens", turn_prompt_tokens)
                    span.set_attribute(
                        "gen_ai.usage.completion_tokens", turn_completion_tokens
                    )
                    span.set_attribute(
                        "gen_ai.usage.cache_read_tokens", turn_cache_read_tokens
                    )
                    span.set_attribute(
                        "gen_ai.usage.cache_creation_tokens",
                        turn_cache_creation_tokens,
                    )
                    if turn_cost_usd is not None:
                        span.set_attribute("gen_ai.usage.cost_usd", turn_cost_usd)
            except Exception:
                logger.debug("Failed to set OTEL cost attributes", exc_info=True)
            try:
                _otel_ctx.__exit__(*sys.exc_info())
            except Exception:
@@ -2391,6 +2410,8 @@ async def stream_chat_completion_sdk(
        cache_creation_tokens=turn_cache_creation_tokens,
        log_prefix=log_prefix,
        cost_usd=turn_cost_usd,
        model=config.model,
        provider="anthropic",
    )

    # --- Persist session messages ---

@@ -4,17 +4,85 @@ Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
1. Append a ``Usage`` record to the session.
2. Log the turn's token counts.
3. Record weighted usage in Redis for rate-limiting.
4. Write a PlatformCostLog entry for admin cost tracking.

This module extracts that common logic so both paths stay in sync.
"""

import asyncio
import logging
import math
import re
import threading

from backend.data.db_accessors import platform_cost_db
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars

from .model import ChatSession, Usage
from .rate_limit import record_token_usage

logger = logging.getLogger(__name__)

# Hold strong references to in-flight cost log tasks to prevent GC.
_pending_log_tasks: set[asyncio.Task[None]] = set()
# Guards all reads and writes to _pending_log_tasks. Done callbacks (discard)
# fire from the event loop thread; drain_pending_cost_logs iterates the set
# from any caller — the lock prevents RuntimeError from concurrent modification.
_pending_log_tasks_lock = threading.Lock()
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
# shared across event loops running in different threads.
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}


def _get_log_semaphore() -> asyncio.Semaphore:
    loop = asyncio.get_running_loop()
    sem = _log_semaphores.get(loop)
    if sem is None:
        sem = asyncio.Semaphore(50)
        _log_semaphores[loop] = sem
    return sem


def _schedule_cost_log(entry: PlatformCostEntry) -> None:
    """Schedule a fire-and-forget cost log via DatabaseManagerAsyncClient RPC."""

    async def _safe_log() -> None:
        async with _get_log_semaphore():
            try:
                await platform_cost_db().log_platform_cost(entry)
            except Exception:
                logger.exception(
                    "Failed to log platform cost for user=%s provider=%s block=%s",
                    entry.user_id,
                    entry.provider,
                    entry.block_name,
                )

    task = asyncio.create_task(_safe_log())
    with _pending_log_tasks_lock:
        _pending_log_tasks.add(task)

    def _remove(t: asyncio.Task[None]) -> None:
        with _pending_log_tasks_lock:
            _pending_log_tasks.discard(t)

    task.add_done_callback(_remove)


# Identifiers used by PlatformCostLog for copilot turns (not tied to a real
# block/credential in the block_cost_config or credentials_store tables).
COPILOT_BLOCK_ID = "copilot"
COPILOT_CREDENTIAL_ID = "copilot_system"


def _copilot_block_name(log_prefix: str) -> str:
    """Extract stable block_name from ``"[SDK][session][T1]"`` -> ``"copilot:SDK"``."""
    match = re.search(r"\[([A-Za-z][A-Za-z0-9_]*)\]", log_prefix)
    if match:
        return f"{COPILOT_BLOCK_ID}:{match.group(1)}"
    tag = log_prefix.strip(" []")
    return f"{COPILOT_BLOCK_ID}:{tag}" if tag else COPILOT_BLOCK_ID


async def persist_and_record_usage(
    *,
@@ -26,6 +94,8 @@ async def persist_and_record_usage(
    cache_creation_tokens: int = 0,
    log_prefix: str = "",
    cost_usd: float | str | None = None,
    model: str | None = None,
    provider: str = "open_router",
) -> int:
    """Persist token usage to session and record for rate limiting.

@@ -38,6 +108,7 @@ async def persist_and_record_usage(
        cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
        log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
        cost_usd: Optional cost for logging (float from SDK, str otherwise).
        provider: Cost provider name (e.g. "anthropic", "open_router").

    Returns:
        The computed total_tokens (prompt + completion; cache excluded).
@@ -47,12 +118,13 @@ async def persist_and_record_usage(
    cache_read_tokens = max(0, cache_read_tokens)
    cache_creation_tokens = max(0, cache_creation_tokens)

    if (
    no_tokens = (
        prompt_tokens <= 0
        and completion_tokens <= 0
        and cache_read_tokens <= 0
        and cache_creation_tokens <= 0
    ):
    )
    if no_tokens and cost_usd is None:
        return 0

    # total_tokens = prompt + completion. Cache tokens are tracked
@@ -73,14 +145,14 @@ async def persist_and_record_usage(

    if cache_read_tokens or cache_creation_tokens:
        logger.info(
            f"{log_prefix} Turn usage: uncached={prompt_tokens}, "
            f"cache_read={cache_read_tokens}, cache_create={cache_creation_tokens}, "
            f"output={completion_tokens}, total={total_tokens}, cost_usd={cost_usd}"
            f"{log_prefix} Turn usage: uncached={prompt_tokens}, cache_read={cache_read_tokens},"
            f" cache_create={cache_creation_tokens}, output={completion_tokens},"
            f" total={total_tokens}, cost_usd={cost_usd}"
        )
    else:
        logger.info(
            f"{log_prefix} Turn usage: prompt={prompt_tokens}, "
            f"completion={completion_tokens}, total={total_tokens}"
            f"{log_prefix} Turn usage: prompt={prompt_tokens}, completion={completion_tokens},"
            f" total={total_tokens}"
        )

    if user_id:
@@ -93,6 +165,54 @@ async def persist_and_record_usage(
                cache_creation_tokens=cache_creation_tokens,
            )
        except Exception as usage_err:
            logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")
            logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)

    # Log to PlatformCostLog for admin cost dashboard.
    # Include entries where cost_usd is set even if token count is 0
    # (e.g. fully-cached Anthropic responses where only cache tokens
    # accumulate a charge without incrementing total_tokens).
    if user_id and (total_tokens > 0 or cost_usd is not None):
        cost_float = None
        if cost_usd is not None:
            try:
                val = float(cost_usd)
                if math.isfinite(val) and val >= 0:
                    cost_float = val
            except (ValueError, TypeError):
                pass

        cost_microdollars = usd_to_microdollars(cost_float)
        session_id = session.session_id if session else None

        if cost_float is not None:
            tracking_type = "cost_usd"
            tracking_amount = cost_float
        else:
            tracking_type = "tokens"
            tracking_amount = total_tokens

        _schedule_cost_log(
            PlatformCostEntry(
                user_id=user_id,
                graph_exec_id=session_id,
                block_id=COPILOT_BLOCK_ID,
                block_name=_copilot_block_name(log_prefix),
                provider=provider,
                credential_id=COPILOT_CREDENTIAL_ID,
                cost_microdollars=cost_microdollars,
                input_tokens=prompt_tokens,
                output_tokens=completion_tokens,
                model=model,
                tracking_type=tracking_type,
                tracking_amount=tracking_amount,
                metadata={
                    "tracking_type": tracking_type,
                    "tracking_amount": tracking_amount,
                    "cache_read_tokens": cache_read_tokens,
                    "cache_creation_tokens": cache_creation_tokens,
                    "source": "copilot",
                },
            )
        )

    return total_tokens

@@ -4,6 +4,7 @@ Covers both the baseline (prompt+completion only) and SDK (with cache breakdown)
calling conventions, session persistence, and rate-limit recording.
"""

import asyncio
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch

@@ -279,3 +280,260 @@ class TestRateLimitRecording:
                completion_tokens=0,
            )
            mock_record.assert_not_awaited()


# ---------------------------------------------------------------------------
# PlatformCostLog integration
# ---------------------------------------------------------------------------


class TestPlatformCostLogging:
    @pytest.mark.asyncio
    async def test_logs_cost_entry_with_cost_usd(self):
        """When cost_usd is provided, tracking_type should be 'cost_usd'."""
        mock_log = AsyncMock()
        with (
            patch(
                "backend.copilot.token_tracking.record_token_usage",
                new_callable=AsyncMock,
            ),
            patch(
                "backend.copilot.token_tracking.platform_cost_db",
                return_value=type(
                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
                )(),
            ),
        ):
            await persist_and_record_usage(
                session=_make_session(),
                user_id="user-cost",
                prompt_tokens=200,
                completion_tokens=100,
                cost_usd=0.005,
                model="gpt-4",
                provider="anthropic",
                log_prefix="[SDK]",
            )
            await asyncio.sleep(0)
            mock_log.assert_awaited_once()
            entry = mock_log.call_args[0][0]
            assert entry.user_id == "user-cost"
            assert entry.provider == "anthropic"
            assert entry.model == "gpt-4"
            assert entry.cost_microdollars == 5000
            assert entry.input_tokens == 200
            assert entry.output_tokens == 100
            assert entry.tracking_type == "cost_usd"
            assert entry.metadata["tracking_type"] == "cost_usd"
            assert entry.metadata["tracking_amount"] == 0.005
            assert entry.block_name == "copilot:SDK"
            assert entry.graph_exec_id == "sess-test"

    @pytest.mark.asyncio
    async def test_logs_cost_entry_without_cost_usd(self):
        """When cost_usd is None, tracking_type should be 'tokens'."""
        mock_log = AsyncMock()
        with (
            patch(
                "backend.copilot.token_tracking.record_token_usage",
                new_callable=AsyncMock,
            ),
            patch(
                "backend.copilot.token_tracking.platform_cost_db",
                return_value=type(
                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
                )(),
            ),
        ):
            await persist_and_record_usage(
                session=None,
                user_id="user-tokens",
                prompt_tokens=100,
                completion_tokens=50,
                log_prefix="[Baseline]",
            )
            await asyncio.sleep(0)
            mock_log.assert_awaited_once()
            entry = mock_log.call_args[0][0]
            assert entry.cost_microdollars is None
            assert entry.tracking_type == "tokens"
            assert entry.metadata["tracking_type"] == "tokens"
            assert entry.metadata["tracking_amount"] == 150
            assert entry.graph_exec_id is None
            assert entry.block_name == "copilot:Baseline"

    @pytest.mark.asyncio
    async def test_skips_cost_log_when_no_user_id(self):
        """No PlatformCostLog entry when user_id is None."""
        mock_log = AsyncMock()
        with (
            patch(
                "backend.copilot.token_tracking.record_token_usage",
                new_callable=AsyncMock,
            ),
            patch(
                "backend.copilot.token_tracking.platform_cost_db",
                return_value=type(
                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
                )(),
            ),
        ):
            await persist_and_record_usage(
                session=None,
                user_id=None,
                prompt_tokens=100,
                completion_tokens=50,
            )
            await asyncio.sleep(0)
            mock_log.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_cost_usd_invalid_string_falls_back_to_tokens(self):
        """Invalid cost_usd string should fall back to tokens tracking."""
        mock_log = AsyncMock()
        with (
            patch(
                "backend.copilot.token_tracking.record_token_usage",
                new_callable=AsyncMock,
            ),
            patch(
                "backend.copilot.token_tracking.platform_cost_db",
                return_value=type(
                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
                )(),
            ),
        ):
            await persist_and_record_usage(
                session=None,
                user_id="user-invalid",
                prompt_tokens=100,
                completion_tokens=50,
                cost_usd="not-a-number",
            )
            await asyncio.sleep(0)
            mock_log.assert_awaited_once()
            entry = mock_log.call_args[0][0]
            assert entry.cost_microdollars is None
            assert entry.metadata["tracking_type"] == "tokens"

    @pytest.mark.asyncio
    async def test_cost_usd_string_number_is_parsed(self):
        """String-encoded cost_usd (e.g. from OpenRouter) should be parsed."""
        mock_log = AsyncMock()
        with (
            patch(
                "backend.copilot.token_tracking.record_token_usage",
                new_callable=AsyncMock,
            ),
            patch(
                "backend.copilot.token_tracking.platform_cost_db",
                return_value=type(
                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
                )(),
            ),
        ):
            await persist_and_record_usage(
                session=None,
                user_id="user-str",
                prompt_tokens=100,
                completion_tokens=50,
                cost_usd="0.01",
            )
            await asyncio.sleep(0)
            mock_log.assert_awaited_once()
            entry = mock_log.call_args[0][0]
            assert entry.cost_microdollars == 10_000
            assert entry.metadata["tracking_type"] == "cost_usd"

    @pytest.mark.asyncio
    async def test_empty_log_prefix_produces_copilot_block_name(self):
        """Empty log_prefix results in block_name='copilot'."""
        mock_log = AsyncMock()
        with (
            patch(
                "backend.copilot.token_tracking.record_token_usage",
                new_callable=AsyncMock,
            ),
            patch(
                "backend.copilot.token_tracking.platform_cost_db",
                return_value=type(
                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
                )(),
            ),
        ):
            await persist_and_record_usage(
                session=None,
                user_id="user-empty",
                prompt_tokens=10,
                completion_tokens=5,
                log_prefix="",
            )
            await asyncio.sleep(0)
            entry = mock_log.call_args[0][0]
            assert entry.block_name == "copilot"

    @pytest.mark.asyncio
    async def test_cache_tokens_included_in_metadata(self):
        """Cache token counts should be present in the metadata."""
        mock_log = AsyncMock()
        with (
            patch(
                "backend.copilot.token_tracking.record_token_usage",
                new_callable=AsyncMock,
            ),
            patch(
                "backend.copilot.token_tracking.platform_cost_db",
                return_value=type(
                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
                )(),
            ),
        ):
            await persist_and_record_usage(
                session=None,
                user_id="user-cache",
                prompt_tokens=100,
                completion_tokens=50,
                cache_read_tokens=5000,
                cache_creation_tokens=300,
            )
            await asyncio.sleep(0)
            entry = mock_log.call_args[0][0]
            assert entry.metadata["cache_read_tokens"] == 5000
            assert entry.metadata["cache_creation_tokens"] == 300
            assert entry.metadata["source"] == "copilot"

    @pytest.mark.asyncio
    async def test_logs_cost_only_when_tokens_zero(self):
        """Zero prompt+completion tokens with cost_usd set still logs the entry."""
        mock_log = AsyncMock()
        with (
            patch(
                "backend.copilot.token_tracking.record_token_usage",
                new_callable=AsyncMock,
            ),
            patch(
                "backend.copilot.token_tracking.platform_cost_db",
                return_value=type(
                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-cached",
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
cost_usd=0.005,
|
||||
model="claude-3-5-sonnet",
|
||||
provider="anthropic",
|
||||
log_prefix="[SDK]",
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
# Guard: total_tokens == 0 but cost_usd is set — must still log
|
||||
mock_log.assert_awaited_once()
|
||||
entry = mock_log.call_args[0][0]
|
||||
assert entry.user_id == "user-cached"
|
||||
assert entry.tracking_type == "cost_usd"
|
||||
assert entry.cost_microdollars == 5000
|
||||
assert entry.input_tokens == 0
|
||||
assert entry.output_tokens == 0
|
||||
|
||||
@@ -142,3 +142,9 @@ def credit_db():
     credit_db = get_database_manager_async_client()
 
     return credit_db
+
+
+def platform_cost_db():
+    from backend.util.clients import get_database_manager_async_client
+
+    return get_database_manager_async_client()
@@ -96,6 +96,7 @@ from backend.data.notifications (
     remove_notifications_from_batch,
 )
 from backend.data.onboarding import increment_onboarding_runs
+from backend.data.platform_cost import log_platform_cost
 from backend.data.understanding import (
     get_business_understanding,
     upsert_business_understanding,
@@ -332,6 +333,9 @@ class DatabaseManager(AppService):
     get_blocks_needing_optimization = _(get_blocks_needing_optimization)
     update_block_optimized_description = _(update_block_optimized_description)
 
+    # ============ Platform Cost Tracking ============ #
+    log_platform_cost = _(log_platform_cost)
+
     # ============ CoPilot Chat Sessions ============ #
     get_chat_session = _(chat_db.get_chat_session)
     create_chat_session = _(chat_db.create_chat_session)
@@ -529,6 +533,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     # ============ Block Descriptions ============ #
     get_blocks_needing_optimization = d.get_blocks_needing_optimization
 
+    # ============ Platform Cost Tracking ============ #
+    log_platform_cost = d.log_platform_cost
+
     # ============ CoPilot Chat Sessions ============ #
     get_chat_session = d.get_chat_session
     create_chat_session = d.create_chat_session
@@ -104,6 +104,11 @@ class User(BaseModel):
         description="User timezone (IANA timezone identifier or 'not-set')",
     )
 
+    # Subscription / rate-limit tier
+    subscription_tier: str | None = Field(
+        default=None, description="Subscription tier (FREE, PRO, BUSINESS, ENTERPRISE)"
+    )
+
     @classmethod
     def from_db(cls, prisma_user: "PrismaUser") -> "User":
         """Convert a database User object to application User model."""
@@ -158,6 +163,7 @@ class User(BaseModel):
             notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
             notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
             timezone=prisma_user.timezone or USER_TIMEZONE_NOT_SET,
+            subscription_tier=prisma_user.subscriptionTier,
         )
@@ -819,6 +825,17 @@ class RefundRequest(BaseModel):
     updated_at: datetime
 
 
+ProviderCostType = Literal[
+    "cost_usd",  # Actual USD cost reported by the provider
+    "tokens",  # LLM token counts (sum of input + output)
+    "characters",  # Per-character billing (TTS providers)
+    "sandbox_seconds",  # Per-second compute billing (e.g. E2B)
+    "walltime_seconds",  # Per-second billing incl. queue/polling
+    "per_run",  # Per-API-call billing with fixed cost
+    "items",  # Per-item billing (lead/organization/result count)
+]
+
+
 class NodeExecutionStats(BaseModel):
     """Execution statistics for a node execution."""
 
@@ -838,32 +855,39 @@ class NodeExecutionStats(BaseModel):
     output_token_count: int = 0
     extra_cost: int = 0
     extra_steps: int = 0
     provider_cost: float | None = None
+    # Type of the provider-reported cost/usage captured above. When set
+    # by a block, resolve_tracking honors this directly instead of
+    # guessing from provider name.
+    provider_cost_type: Optional[ProviderCostType] = None
     # Moderation fields
     cleared_inputs: Optional[dict[str, list[str]]] = None
     cleared_outputs: Optional[dict[str, list[str]]] = None
 
     def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
-        """Mutate this instance by adding another NodeExecutionStats."""
+        """Mutate this instance by adding another NodeExecutionStats.
+
+        Avoids calling model_dump() twice per merge (called on every
+        merge_stats() from ~20+ blocks); reads via getattr/vars instead.
+        """
         if not isinstance(other, NodeExecutionStats):
             return NotImplemented
 
-        stats_dict = other.model_dump()
-        current_stats = self.model_dump()
-
-        for key, value in stats_dict.items():
-            if key not in current_stats:
-                # Field doesn't exist yet, just set it
-                setattr(self, key, value)
-            elif isinstance(value, dict) and isinstance(current_stats[key], dict):
-                current_stats[key].update(value)
-                setattr(self, key, current_stats[key])
-            elif isinstance(value, (int, float)) and isinstance(
-                current_stats[key], (int, float)
-            ):
-                setattr(self, key, current_stats[key] + value)
-            elif isinstance(value, list) and isinstance(current_stats[key], list):
-                current_stats[key].extend(value)
-                setattr(self, key, current_stats[key])
+        for key in type(other).model_fields:
+            value = getattr(other, key)
+            if value is None:
+                # Never overwrite an existing value with None
+                continue
+            current = getattr(self, key, None)
+            if current is None:
+                # Field doesn't exist yet or is None, just set it
+                setattr(self, key, value)
+            elif isinstance(value, dict) and isinstance(current, dict):
+                current.update(value)
+            elif isinstance(value, (int, float)) and isinstance(current, (int, float)):
+                setattr(self, key, current + value)
+            elif isinstance(value, list) and isinstance(current, list):
+                current.extend(value)
             else:
                 setattr(self, key, value)
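The rewritten `__iadd__` above merges field by field: `None` on the incoming side is skipped, a `None` current value takes a first write, dicts are updated, numbers are added, and lists are extended. A standalone sketch of the same merge rule on plain dicts (illustrative only; the real code iterates the pydantic model's `model_fields`):

```python
def merge_stats(current: dict, other: dict) -> dict:
    """In-place merge following the __iadd__ rules sketched above."""
    for key, value in other.items():
        if value is None:
            continue  # never overwrite an existing value with None
        cur = current.get(key)
        if cur is None:
            current[key] = value  # first write into an unset field
        elif isinstance(value, dict) and isinstance(cur, dict):
            cur.update(value)  # shallow-merge dict fields
        elif isinstance(value, (int, float)) and isinstance(cur, (int, float)):
            current[key] = cur + value  # accumulate numeric fields
        elif isinstance(value, list) and isinstance(cur, list):
            cur.extend(value)  # concatenate list fields
        else:
            current[key] = value  # type mismatch: last write wins
    return current


stats = {"input_tokens": 100, "cleared": {"a": ["x"]}, "error": "boom"}
merge_stats(stats, {"input_tokens": 200, "cleared": {"b": ["y"]}, "error": None})
print(stats)  # {'input_tokens': 300, 'cleared': {'a': ['x'], 'b': ['y']}, 'error': 'boom'}
```

The `None`-skip rule is what the `TestNodeExecutionStatsIadd` cases below this hunk exercise: a default-constructed stats object must never clear values on the left-hand side.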
@@ -1,7 +1,7 @@
 import pytest
 from pydantic import SecretStr
 
-from backend.data.model import HostScopedCredentials
+from backend.data.model import HostScopedCredentials, NodeExecutionStats
 
 
 class TestHostScopedCredentials:
@@ -166,3 +166,84 @@ class TestHostScopedCredentials:
         )
 
         assert creds.matches_url(test_url) == expected
+
+
+class TestNodeExecutionStatsIadd:
+    def test_adds_numeric_fields(self):
+        a = NodeExecutionStats(input_token_count=100, output_token_count=50)
+        b = NodeExecutionStats(input_token_count=200, output_token_count=30)
+        a += b
+        assert a.input_token_count == 300
+        assert a.output_token_count == 80
+
+    def test_none_does_not_overwrite(self):
+        a = NodeExecutionStats(provider_cost=0.5, error="some error")
+        b = NodeExecutionStats(provider_cost=None, error=None)
+        a += b
+        assert a.provider_cost == 0.5
+        assert a.error == "some error"
+
+    def test_none_is_skipped_preserving_existing_value(self):
+        a = NodeExecutionStats(input_token_count=100)
+        b = NodeExecutionStats()
+        a += b
+        assert a.input_token_count == 100
+
+    def test_dict_fields_are_merged(self):
+        a = NodeExecutionStats(
+            cleared_inputs={"field1": ["val1"]},
+        )
+        b = NodeExecutionStats(
+            cleared_inputs={"field2": ["val2"]},
+        )
+        a += b
+        assert a.cleared_inputs == {"field1": ["val1"], "field2": ["val2"]}
+
+    def test_returns_self(self):
+        a = NodeExecutionStats()
+        b = NodeExecutionStats(input_token_count=10)
+        result = a.__iadd__(b)
+        assert result is a
+
+    def test_not_implemented_for_non_stats(self):
+        a = NodeExecutionStats()
+        result = a.__iadd__("not a stats")  # type: ignore[arg-type]
+        assert result is NotImplemented
+
+    def test_error_none_does_not_clear_existing_error(self):
+        a = NodeExecutionStats(error="existing error")
+        b = NodeExecutionStats(error=None)
+        a += b
+        assert a.error == "existing error"
+
+    def test_provider_cost_none_does_not_clear_existing_cost(self):
+        a = NodeExecutionStats(provider_cost=0.05)
+        b = NodeExecutionStats(provider_cost=None)
+        a += b
+        assert a.provider_cost == 0.05
+
+    def test_provider_cost_accumulates_when_both_set(self):
+        a = NodeExecutionStats(provider_cost=0.01)
+        b = NodeExecutionStats(provider_cost=0.02)
+        a += b
+        assert abs((a.provider_cost or 0) - 0.03) < 1e-9
+
+    def test_provider_cost_first_write_from_none(self):
+        a = NodeExecutionStats()
+        b = NodeExecutionStats(provider_cost=0.05)
+        a += b
+        assert a.provider_cost == 0.05
+
+    def test_provider_cost_type_first_write_from_none(self):
+        """Writing provider_cost_type into a stats with None sets it."""
+        a = NodeExecutionStats()
+        b = NodeExecutionStats(provider_cost_type="characters")
+        a += b
+        assert a.provider_cost_type == "characters"
+
+    def test_provider_cost_type_none_does_not_overwrite(self):
+        """A None provider_cost_type from other must not clear an existing value."""
+        a = NodeExecutionStats(provider_cost_type="tokens")
+        b = NodeExecutionStats()
+        a += b
+        assert a.provider_cost_type == "tokens"
390
autogpt_platform/backend/backend/data/platform_cost.py
Normal file
@@ -0,0 +1,390 @@
import asyncio
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Any

from pydantic import BaseModel

from backend.data.db import execute_raw_with_schema, query_raw_with_schema

logger = logging.getLogger(__name__)

MICRODOLLARS_PER_USD = 1_000_000

# Dashboard query limits — keep in sync with the SQL queries below
MAX_PROVIDER_ROWS = 500
MAX_USER_ROWS = 100

# Default date range for dashboard queries when no start date is provided.
# Prevents full-table scans on large deployments.
DEFAULT_DASHBOARD_DAYS = 30


def usd_to_microdollars(cost_usd: float | None) -> int | None:
    """Convert a USD amount (float) to microdollars (int). None-safe."""
    if cost_usd is None:
        return None
    return round(cost_usd * MICRODOLLARS_PER_USD)


class PlatformCostEntry(BaseModel):
    user_id: str
    graph_exec_id: str | None = None
    node_exec_id: str | None = None
    graph_id: str | None = None
    node_id: str | None = None
    block_id: str
    block_name: str
    provider: str
    credential_id: str
    cost_microdollars: int | None = None
    input_tokens: int | None = None
    output_tokens: int | None = None
    data_size: int | None = None
    duration: float | None = None
    model: str | None = None
    tracking_type: str | None = None
    tracking_amount: float | None = None
    metadata: dict[str, Any] | None = None


async def log_platform_cost(entry: PlatformCostEntry) -> None:
    await execute_raw_with_schema(
        """
        INSERT INTO {schema_prefix}"PlatformCostLog"
            ("id", "createdAt", "userId", "graphExecId", "nodeExecId",
             "graphId", "nodeId", "blockId", "blockName", "provider",
             "credentialId", "costMicrodollars", "inputTokens", "outputTokens",
             "dataSize", "duration", "model", "trackingType", "trackingAmount",
             "metadata")
        VALUES (
            gen_random_uuid(), NOW(), $1, $2, $3, $4, $5, $6, $7, $8, $9,
            $10, $11, $12, $13, $14, $15, $16, $17, $18::jsonb
        )
        """,
        entry.user_id,
        entry.graph_exec_id,
        entry.node_exec_id,
        entry.graph_id,
        entry.node_id,
        entry.block_id,
        entry.block_name,
        # Normalize to lowercase so the (provider, createdAt) index is always
        # used without LOWER() on the read side.
        entry.provider.lower(),
        entry.credential_id,
        entry.cost_microdollars,
        entry.input_tokens,
        entry.output_tokens,
        entry.data_size,
        entry.duration,
        entry.model,
        entry.tracking_type,
        entry.tracking_amount,
        _json_or_none(entry.metadata),
    )


# Bound the number of concurrent cost-log DB inserts to prevent unbounded
# task/connection growth under sustained load or DB slowness.
_log_semaphore = asyncio.Semaphore(50)


async def log_platform_cost_safe(entry: PlatformCostEntry) -> None:
    """Fire-and-forget wrapper that never raises."""
    try:
        async with _log_semaphore:
            await log_platform_cost(entry)
    except Exception:
        logger.exception(
            "Failed to log platform cost for user=%s provider=%s block=%s",
            entry.user_id,
            entry.provider,
            entry.block_name,
        )


def _json_or_none(data: dict[str, Any] | None) -> str | None:
    if data is None:
        return None
    return json.dumps(data)


def _mask_email(email: str | None) -> str | None:
    """Mask an email address to reduce PII exposure in admin API responses.

    Turns 'user@example.com' into 'us***@example.com'.
    Handles short local parts gracefully (e.g. 'a@b.com' → 'a***@b.com').
    """
    if not email:
        return email
    at = email.find("@")
    if at < 0:
        return "***"
    local = email[:at]
    domain = email[at:]
    visible = local[:2] if len(local) >= 2 else local[:1]
    return f"{visible}***{domain}"


class ProviderCostSummary(BaseModel):
    provider: str
    tracking_type: str | None = None
    total_cost_microdollars: int
    total_input_tokens: int
    total_output_tokens: int
    total_duration_seconds: float = 0.0
    total_tracking_amount: float = 0.0
    request_count: int


class UserCostSummary(BaseModel):
    user_id: str | None = None
    email: str | None = None
    total_cost_microdollars: int
    total_input_tokens: int
    total_output_tokens: int
    request_count: int


class CostLogRow(BaseModel):
    id: str
    created_at: datetime
    user_id: str | None = None
    email: str | None = None
    graph_exec_id: str | None = None
    node_exec_id: str | None = None
    block_name: str
    provider: str
    tracking_type: str | None = None
    cost_microdollars: int | None = None
    input_tokens: int | None = None
    output_tokens: int | None = None
    duration: float | None = None
    model: str | None = None


class PlatformCostDashboard(BaseModel):
    by_provider: list[ProviderCostSummary]
    by_user: list[UserCostSummary]
    total_cost_microdollars: int
    total_requests: int
    total_users: int


def _build_where(
    start: datetime | None,
    end: datetime | None,
    provider: str | None,
    user_id: str | None,
    table_alias: str = "",
) -> tuple[str, list[Any]]:
    prefix = f"{table_alias}." if table_alias else ""
    clauses: list[str] = []
    params: list[Any] = []
    idx = 1

    if start:
        clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
        params.append(start)
        idx += 1
    if end:
        clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
        params.append(end)
        idx += 1
    if provider:
        # Provider names are normalized to lowercase at write time so a plain
        # equality check is sufficient and the (provider, createdAt) index is used.
        clauses.append(f'{prefix}"provider" = ${idx}')
        params.append(provider.lower())
        idx += 1
    if user_id:
        clauses.append(f'{prefix}"userId" = ${idx}')
        params.append(user_id)
        idx += 1

    return (" AND ".join(clauses) if clauses else "TRUE", params)


async def get_platform_cost_dashboard(
    start: datetime | None = None,
    end: datetime | None = None,
    provider: str | None = None,
    user_id: str | None = None,
) -> PlatformCostDashboard:
    """Aggregate platform cost logs for the admin dashboard.

    Note: by_provider rows are keyed on (provider, tracking_type). A single
    provider can therefore appear in multiple rows if it has entries with
    different billing models (e.g. "openai" with both "tokens" and "cost_usd"
    if pricing is later added for some entries). Frontend treats each row
    independently rather than as a provider primary key.

    Defaults to the last DEFAULT_DASHBOARD_DAYS days when no start date is
    provided to avoid full-table scans on large deployments.
    """
    if start is None:
        start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
    where_p, params_p = _build_where(start, end, provider, user_id, "p")

    by_provider_rows, by_user_rows, total_user_rows = await asyncio.gather(
        query_raw_with_schema(
            f"""
            SELECT
                p."provider",
                p."trackingType" AS tracking_type,
                COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
                COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
                COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
                COALESCE(SUM(p."duration"), 0)::float AS total_duration,
                COALESCE(SUM(p."trackingAmount"), 0)::float AS total_tracking_amount,
                COUNT(*)::bigint AS request_count
            FROM {{schema_prefix}}"PlatformCostLog" p
            WHERE {where_p}
            GROUP BY p."provider", p."trackingType"
            ORDER BY total_cost DESC
            LIMIT {MAX_PROVIDER_ROWS}
            """,
            *params_p,
        ),
        query_raw_with_schema(
            f"""
            SELECT
                p."userId" AS user_id,
                u."email",
                COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
                COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
                COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
                COUNT(*)::bigint AS request_count
            FROM {{schema_prefix}}"PlatformCostLog" p
            LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
            WHERE {where_p}
            GROUP BY p."userId", u."email"
            ORDER BY total_cost DESC
            LIMIT {MAX_USER_ROWS}
            """,
            *params_p,
        ),
        query_raw_with_schema(
            f"""
            SELECT COUNT(DISTINCT p."userId")::bigint AS cnt
            FROM {{schema_prefix}}"PlatformCostLog" p
            WHERE {where_p}
            """,
            *params_p,
        ),
    )

    # Use the exact COUNT(DISTINCT userId) so total_users is not capped at
    # MAX_USER_ROWS (which would silently report 100 for >100 active users).
    total_users = int(total_user_rows[0]["cnt"]) if total_user_rows else 0
    total_cost = sum(r["total_cost"] for r in by_provider_rows)
    total_requests = sum(r["request_count"] for r in by_provider_rows)

    return PlatformCostDashboard(
        by_provider=[
            ProviderCostSummary(
                provider=r["provider"],
                tracking_type=r.get("tracking_type"),
                total_cost_microdollars=r["total_cost"],
                total_input_tokens=r["total_input_tokens"],
                total_output_tokens=r["total_output_tokens"],
                total_duration_seconds=r.get("total_duration", 0.0),
                total_tracking_amount=r.get("total_tracking_amount", 0.0),
                request_count=r["request_count"],
            )
            for r in by_provider_rows
        ],
        by_user=[
            UserCostSummary(
                user_id=r.get("user_id"),
                email=_mask_email(r.get("email")),
                total_cost_microdollars=r["total_cost"],
                total_input_tokens=r["total_input_tokens"],
                total_output_tokens=r["total_output_tokens"],
                request_count=r["request_count"],
            )
            for r in by_user_rows
        ],
        total_cost_microdollars=total_cost,
        total_requests=total_requests,
        total_users=total_users,
    )


async def get_platform_cost_logs(
    start: datetime | None = None,
    end: datetime | None = None,
    provider: str | None = None,
    user_id: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[CostLogRow], int]:
    if start is None:
        start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
    where_sql, params = _build_where(start, end, provider, user_id, "p")

    offset = (page - 1) * page_size
    limit_idx = len(params) + 1
    offset_idx = len(params) + 2

    count_rows, rows = await asyncio.gather(
        query_raw_with_schema(
            f"""
            SELECT COUNT(*)::bigint AS cnt
            FROM {{schema_prefix}}"PlatformCostLog" p
            WHERE {where_sql}
            """,
            *params,
        ),
        query_raw_with_schema(
            f"""
            SELECT
                p."id",
                p."createdAt" AS created_at,
                p."userId" AS user_id,
                u."email",
                p."graphExecId" AS graph_exec_id,
                p."nodeExecId" AS node_exec_id,
                p."blockName" AS block_name,
                p."provider",
                p."trackingType" AS tracking_type,
                p."costMicrodollars" AS cost_microdollars,
                p."inputTokens" AS input_tokens,
                p."outputTokens" AS output_tokens,
                p."duration",
                p."model"
            FROM {{schema_prefix}}"PlatformCostLog" p
            LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
            WHERE {where_sql}
            ORDER BY p."createdAt" DESC, p."id" DESC
            LIMIT ${limit_idx} OFFSET ${offset_idx}
            """,
            *params,
            page_size,
            offset,
        ),
    )
    total = count_rows[0]["cnt"] if count_rows else 0

    logs = [
        CostLogRow(
            id=r["id"],
            created_at=r["created_at"],
            user_id=r.get("user_id"),
            email=_mask_email(r.get("email")),
            graph_exec_id=r.get("graph_exec_id"),
            node_exec_id=r.get("node_exec_id"),
            block_name=r["block_name"],
            provider=r["provider"],
            tracking_type=r.get("tracking_type"),
            cost_microdollars=r.get("cost_microdollars"),
            input_tokens=r.get("input_tokens"),
            output_tokens=r.get("output_tokens"),
            duration=r.get("duration"),
            model=r.get("model"),
        )
        for r in rows
    ]
    return logs, total
266
autogpt_platform/backend/backend/data/platform_cost_test.py
Normal file
@@ -0,0 +1,266 @@
"""Unit tests for helpers and async functions in platform_cost module."""

from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch

import pytest

from .platform_cost import (
    PlatformCostEntry,
    _build_where,
    _json_or_none,
    get_platform_cost_dashboard,
    get_platform_cost_logs,
    log_platform_cost,
    log_platform_cost_safe,
)


class TestJsonOrNone:
    def test_returns_none_for_none(self):
        assert _json_or_none(None) is None

    def test_returns_json_string_for_dict(self):
        result = _json_or_none({"key": "value", "num": 42})
        assert result is not None
        assert '"key"' in result
        assert '"value"' in result

    def test_returns_json_for_empty_dict(self):
        assert _json_or_none({}) == "{}"


class TestBuildWhere:
    def test_no_filters_returns_true(self):
        sql, params = _build_where(None, None, None, None)
        assert sql == "TRUE"
        assert params == []

    def test_start_only(self):
        dt = datetime(2026, 1, 1, tzinfo=timezone.utc)
        sql, params = _build_where(dt, None, None, None)
        assert '"createdAt" >= $1::timestamptz' in sql
        assert params == [dt]

    def test_end_only(self):
        dt = datetime(2026, 6, 1, tzinfo=timezone.utc)
        sql, params = _build_where(None, dt, None, None)
        assert '"createdAt" <= $1::timestamptz' in sql
        assert params == [dt]

    def test_provider_only(self):
        # Provider names are normalized to lowercase at write time, so the
        # filter uses a plain equality check. The input is also lowercased so
        # "OpenAI" and "openai" both match stored rows.
        sql, params = _build_where(None, None, "OpenAI", None)
        assert '"provider" = $1' in sql
        assert params == ["openai"]

    def test_user_id_only(self):
        sql, params = _build_where(None, None, None, "user-123")
        assert '"userId" = $1' in sql
        assert params == ["user-123"]

    def test_all_filters(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        end = datetime(2026, 6, 1, tzinfo=timezone.utc)
        sql, params = _build_where(start, end, "Anthropic", "u1")
        assert "$1" in sql
        assert "$2" in sql
        assert "$3" in sql
        assert "$4" in sql
        assert len(params) == 4
        # Provider is lowercased at filter time to match stored lowercase values.
        assert params == [start, end, "anthropic", "u1"]

    def test_table_alias(self):
        dt = datetime(2026, 1, 1, tzinfo=timezone.utc)
        sql, params = _build_where(dt, None, None, None, table_alias="p")
        assert 'p."createdAt"' in sql
        assert params == [dt]

    def test_clauses_joined_with_and(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        end = datetime(2026, 6, 1, tzinfo=timezone.utc)
        sql, _ = _build_where(start, end, None, None)
        assert " AND " in sql


def _make_entry(**overrides: object) -> PlatformCostEntry:
    return PlatformCostEntry.model_validate(
        {
            "user_id": "user-1",
            "block_id": "block-1",
            "block_name": "TestBlock",
            "provider": "openai",
            "credential_id": "cred-1",
            **overrides,
        }
    )


class TestLogPlatformCost:
    @pytest.mark.asyncio
    async def test_calls_execute_raw_with_schema(self):
        mock_exec = AsyncMock()
        with patch("backend.data.platform_cost.execute_raw_with_schema", new=mock_exec):
            entry = _make_entry(
                input_tokens=100,
                output_tokens=50,
                cost_microdollars=5000,
                model="gpt-4",
                metadata={"key": "val"},
            )
            await log_platform_cost(entry)
        mock_exec.assert_awaited_once()
        args = mock_exec.call_args
        assert args[0][1] == "user-1"  # user_id is first param
        assert args[0][6] == "block-1"  # block_id
        assert args[0][7] == "TestBlock"  # block_name

    @pytest.mark.asyncio
    async def test_metadata_none_passes_none(self):
        mock_exec = AsyncMock()
        with patch("backend.data.platform_cost.execute_raw_with_schema", new=mock_exec):
            entry = _make_entry(metadata=None)
            await log_platform_cost(entry)
        args = mock_exec.call_args
        assert args[0][-1] is None  # last arg is metadata json


class TestLogPlatformCostSafe:
    @pytest.mark.asyncio
    async def test_does_not_raise_on_error(self):
        with patch(
            "backend.data.platform_cost.execute_raw_with_schema",
            new=AsyncMock(side_effect=RuntimeError("DB down")),
        ):
            entry = _make_entry()
            await log_platform_cost_safe(entry)

    @pytest.mark.asyncio
    async def test_succeeds_when_no_error(self):
        mock_exec = AsyncMock()
        with patch("backend.data.platform_cost.execute_raw_with_schema", new=mock_exec):
            entry = _make_entry()
            await log_platform_cost_safe(entry)
        mock_exec.assert_awaited_once()


class TestGetPlatformCostDashboard:
    @pytest.mark.asyncio
    async def test_returns_dashboard_with_data(self):
        provider_rows = [
            {
                "provider": "openai",
                "tracking_type": "tokens",
                "total_cost": 5000,
                "total_input_tokens": 1000,
                "total_output_tokens": 500,
                "total_duration": 10.5,
                "request_count": 3,
            }
        ]
        user_rows = [
            {
                "user_id": "u1",
                "email": "a@b.com",
                "total_cost": 5000,
                "total_input_tokens": 1000,
                "total_output_tokens": 500,
                "request_count": 3,
            }
        ]
        # Dashboard runs 3 queries: by_provider, by_user, COUNT(DISTINCT userId).
        mock_query = AsyncMock(side_effect=[provider_rows, user_rows, [{"cnt": 1}]])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            dashboard = await get_platform_cost_dashboard()
        assert dashboard.total_cost_microdollars == 5000
        assert dashboard.total_requests == 3
        assert dashboard.total_users == 1
        assert len(dashboard.by_provider) == 1
        assert dashboard.by_provider[0].provider == "openai"
        assert dashboard.by_provider[0].tracking_type == "tokens"
        assert dashboard.by_provider[0].total_duration_seconds == 10.5
        assert len(dashboard.by_user) == 1
        assert dashboard.by_user[0].email == "a***@b.com"

    @pytest.mark.asyncio
    async def test_returns_empty_dashboard(self):
        mock_query = AsyncMock(side_effect=[[], [], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            dashboard = await get_platform_cost_dashboard()
        assert dashboard.total_cost_microdollars == 0
        assert dashboard.total_requests == 0
        assert dashboard.total_users == 0
        assert dashboard.by_provider == []
        assert dashboard.by_user == []

    @pytest.mark.asyncio
    async def test_passes_filters_to_queries(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        mock_query = AsyncMock(side_effect=[[], [], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            await get_platform_cost_dashboard(
                start=start, provider="openai", user_id="u1"
            )
        assert mock_query.await_count == 3
        first_call_sql = mock_query.call_args_list[0][0][0]
        assert "createdAt" in first_call_sql


class TestGetPlatformCostLogs:
    @pytest.mark.asyncio
    async def test_returns_logs_and_total(self):
        count_rows = [{"cnt": 1}]
        log_rows = [
            {
                "id": "log-1",
                "created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
                "user_id": "u1",
                "email": "a@b.com",
                "graph_exec_id": "g1",
                "node_exec_id": "n1",
                "block_name": "TestBlock",
                "provider": "openai",
                "tracking_type": "tokens",
                "cost_microdollars": 5000,
                "input_tokens": 100,
                "output_tokens": 50,
                "duration": 1.5,
                "model": "gpt-4",
            }
        ]
        mock_query = AsyncMock(side_effect=[count_rows, log_rows])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            logs, total = await get_platform_cost_logs(page=1, page_size=10)
        assert total == 1
        assert len(logs) == 1
        assert logs[0].id == "log-1"
        assert logs[0].provider == "openai"
        assert logs[0].model == "gpt-4"

    @pytest.mark.asyncio
    async def test_returns_empty_when_no_data(self):
        mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
logs, total = await get_platform_cost_logs()
|
||||
assert total == 0
|
||||
assert logs == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pagination_offset(self):
|
||||
mock_query = AsyncMock(side_effect=[[{"cnt": 100}], []])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
logs, total = await get_platform_cost_logs(page=3, page_size=25)
|
||||
assert total == 100
|
||||
second_call_args = mock_query.call_args_list[1][0]
|
||||
assert 25 in second_call_args # page_size
|
||||
assert 50 in second_call_args # offset = (3-1) * 25
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_count_returns_zero(self):
|
||||
mock_query = AsyncMock(side_effect=[[], []])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
logs, total = await get_platform_cost_logs()
|
||||
assert total == 0
|
||||
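The pagination test above expects page 3 with page_size 25 to produce an offset of 50. That arithmetic can be sketched standalone; `page_to_offset` is a hypothetical helper name for illustration, not a function from the module under test:

```python
def page_to_offset(page: int, page_size: int) -> int:
    """Convert a 1-indexed page number to a row offset for LIMIT/OFFSET queries."""
    # Page 1 starts at row 0; page 3 with page_size 25 skips the first 50 rows.
    return (page - 1) * page_size
```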
autogpt_platform/backend/backend/executor/cost_tracking.py (new file, 291 lines)
@@ -0,0 +1,291 @@
"""Helpers for platform cost tracking on system-credential block executions."""

import asyncio
import logging
import threading
from typing import TYPE_CHECKING, Any, cast

from backend.blocks._base import Block, BlockSchema
from backend.copilot.token_tracking import _pending_log_tasks as _copilot_tasks
from backend.copilot.token_tracking import (
    _pending_log_tasks_lock as _copilot_tasks_lock,
)
from backend.data.execution import NodeExecutionEntry
from backend.data.model import NodeExecutionStats
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
from backend.executor.utils import block_usage_cost
from backend.integrations.credentials_store import is_system_credential
from backend.integrations.providers import ProviderName

if TYPE_CHECKING:
    from backend.data.db_manager import DatabaseManagerAsyncClient

logger = logging.getLogger(__name__)

# Provider groupings by billing model — used when the block didn't explicitly
# declare stats.provider_cost_type and we fall back to provider-name
# heuristics. Values match ProviderName enum values.
_CHARACTER_BILLED_PROVIDERS = frozenset(
    {ProviderName.D_ID.value, ProviderName.ELEVENLABS.value}
)
_WALLTIME_BILLED_PROVIDERS = frozenset(
    {
        ProviderName.FAL.value,
        ProviderName.REVID.value,
        ProviderName.REPLICATE.value,
    }
)

# Hold strong references to in-flight log tasks so the event loop doesn't
# garbage-collect them mid-execution. Tasks remove themselves on completion.
# _pending_log_tasks_lock guards all reads and writes: worker threads call
# discard() via done callbacks while drain_pending_cost_logs() iterates.
_pending_log_tasks: set[asyncio.Task] = set()
_pending_log_tasks_lock = threading.Lock()
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
# shared across event loops running in different threads. Key by loop instance
# so each executor worker thread gets its own semaphore.
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}


def _get_log_semaphore() -> asyncio.Semaphore:
    loop = asyncio.get_running_loop()
    sem = _log_semaphores.get(loop)
    if sem is None:
        sem = asyncio.Semaphore(50)
        _log_semaphores[loop] = sem
    return sem


async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
    """Await all in-flight cost log tasks with a timeout.

    Drains both the executor cost log tasks (_pending_log_tasks in this module,
    used for block execution cost tracking via DatabaseManagerAsyncClient) and
    the copilot cost log tasks (token_tracking._pending_log_tasks, used for
    copilot LLM turns via platform_cost_db()).

    Call this during graceful shutdown to flush pending INSERT tasks before
    the process exits. Tasks that don't complete within `timeout` seconds are
    abandoned and their failures are already logged by _safe_log.
    """
    # asyncio.wait() requires all tasks to belong to the running event loop.
    # _pending_log_tasks is shared across executor worker threads (each with
    # its own loop), so filter to only tasks owned by the current loop.
    # Acquire the lock to take a consistent snapshot (worker threads call
    # discard() via done callbacks concurrently with this iteration).
    current_loop = asyncio.get_running_loop()
    with _pending_log_tasks_lock:
        all_pending = [t for t in _pending_log_tasks if t.get_loop() is current_loop]
    if all_pending:
        logger.info("Draining %d executor cost log task(s)", len(all_pending))
        _, still_pending = await asyncio.wait(all_pending, timeout=timeout)
        if still_pending:
            logger.warning(
                "%d executor cost log task(s) did not complete within %.1fs",
                len(still_pending),
                timeout,
            )
    # Also drain copilot cost log tasks (token_tracking._pending_log_tasks)
    with _copilot_tasks_lock:
        copilot_pending = [t for t in _copilot_tasks if t.get_loop() is current_loop]
    if copilot_pending:
        logger.info("Draining %d copilot cost log task(s)", len(copilot_pending))
        _, still_pending = await asyncio.wait(copilot_pending, timeout=timeout)
        if still_pending:
            logger.warning(
                "%d copilot cost log task(s) did not complete within %.1fs",
                len(still_pending),
                timeout,
            )


def _schedule_log(
    db_client: "DatabaseManagerAsyncClient", entry: PlatformCostEntry
) -> None:
    async def _safe_log() -> None:
        async with _get_log_semaphore():
            try:
                await db_client.log_platform_cost(entry)
            except Exception:
                logger.exception(
                    "Failed to log platform cost for user=%s provider=%s block=%s",
                    entry.user_id,
                    entry.provider,
                    entry.block_name,
                )

    task = asyncio.create_task(_safe_log())
    with _pending_log_tasks_lock:
        _pending_log_tasks.add(task)

    def _remove(t: asyncio.Task) -> None:
        with _pending_log_tasks_lock:
            _pending_log_tasks.discard(t)

    task.add_done_callback(_remove)


def _extract_model_name(raw: str | dict | None) -> str | None:
    """Return a string model name from a block input field, or None.

    Handles str (returned as-is), dict (e.g. an enum wrapper, skipped), and
    None (no model field). Unexpected types are coerced to str as a fallback.
    """
    if raw is None:
        return None
    if isinstance(raw, str):
        return raw
    if isinstance(raw, dict):
        return None
    return str(raw)


def resolve_tracking(
    provider: str,
    stats: NodeExecutionStats,
    input_data: dict[str, Any],
) -> tuple[str, float]:
    """Return (tracking_type, tracking_amount) based on provider billing model.

    Preference order:
    1. Block-declared: if the block set `provider_cost_type` on its stats,
       honor it directly (paired with `provider_cost` as the amount).
    2. Heuristic fallback: infer from `provider_cost`/token counts, then
       from provider name for per-character / per-second billing.
    """
    # 1. Block explicitly declared its cost type (only when an amount is present)
    if stats.provider_cost_type and stats.provider_cost is not None:
        return stats.provider_cost_type, stats.provider_cost

    # 2. Provider returned actual USD cost (OpenRouter, Exa)
    if stats.provider_cost is not None:
        return "cost_usd", stats.provider_cost

    # 3. LLM providers: track by tokens
    if stats.input_token_count or stats.output_token_count:
        return "tokens", float(
            (stats.input_token_count or 0) + (stats.output_token_count or 0)
        )

    # 4. Provider-specific billing heuristics

    # TTS: billed per character of input text
    if provider == ProviderName.UNREAL_SPEECH.value:
        text = input_data.get("text", "")
        return "characters", float(len(text)) if isinstance(text, str) else 0.0

    # D-ID + ElevenLabs voice: billed per character of script
    if provider in _CHARACTER_BILLED_PROVIDERS:
        text = (
            input_data.get("script_input", "")
            or input_data.get("text", "")
            or input_data.get("script", "")  # VideoNarrationBlock uses `script`
        )
        return "characters", float(len(text)) if isinstance(text, str) else 0.0

    # E2B: billed per second of sandbox time
    if provider == ProviderName.E2B.value:
        return "sandbox_seconds", round(stats.walltime, 3) if stats.walltime else 0.0

    # Video/image gen: walltime includes queue + generation + polling
    if provider in _WALLTIME_BILLED_PROVIDERS:
        return "walltime_seconds", round(stats.walltime, 3) if stats.walltime else 0.0

    # Per-request: Google Maps, Ideogram, Nvidia, Apollo, etc.
    # All billed per API call - count 1 per block execution.
    return "per_run", 1.0


async def log_system_credential_cost(
    node_exec: NodeExecutionEntry,
    block: Block,
    stats: NodeExecutionStats,
    db_client: "DatabaseManagerAsyncClient",
) -> None:
    """Check if a system credential was used and log the platform cost.

    Routes through DatabaseManagerAsyncClient so the write goes via the
    message-passing DB service rather than calling Prisma directly (which
    is not connected in the executor process).

    Logs only the first matching system credential field (one log per
    execution). Any unexpected error is caught and logged — cost logging
    is strictly best-effort and must never disrupt block execution.

    Note: costMicrodollars is left null for providers that don't return
    a USD cost. The credit_cost in metadata captures our internal credit
    charge as a proxy.
    """
    try:
        if node_exec.execution_context.dry_run:
            return

        input_data = node_exec.inputs
        input_model = cast(type[BlockSchema], block.input_schema)

        for field_name in input_model.get_credentials_fields():
            cred_data = input_data.get(field_name)
            if not cred_data or not isinstance(cred_data, dict):
                continue
            cred_id = cred_data.get("id", "")
            if not cred_id or not is_system_credential(cred_id):
                continue

            model_name = _extract_model_name(input_data.get("model"))

            credit_cost, _ = block_usage_cost(block=block, input_data=input_data)

            provider_name = cred_data.get("provider", "unknown")
            tracking_type, tracking_amount = resolve_tracking(
                provider=provider_name,
                stats=stats,
                input_data=input_data,
            )

            # Only treat provider_cost as USD when the tracking type says so.
            # For other types (items, characters, per_run, ...) the
            # provider_cost field holds the raw amount, not a dollar value.
            # Use tracking_amount (the normalized value from resolve_tracking)
            # rather than raw stats.provider_cost to avoid unit mismatches.
            cost_microdollars = None
            if tracking_type == "cost_usd":
                cost_microdollars = usd_to_microdollars(tracking_amount)

            meta: dict[str, Any] = {
                "tracking_type": tracking_type,
                "tracking_amount": tracking_amount,
            }
            if credit_cost is not None:
                meta["credit_cost"] = credit_cost
            if stats.provider_cost is not None:
                # Use 'provider_cost_raw' — the value's unit varies by tracking
                # type (USD for cost_usd, count for items/characters/per_run, etc.)
                meta["provider_cost_raw"] = stats.provider_cost

            _schedule_log(
                db_client,
                PlatformCostEntry(
                    user_id=node_exec.user_id,
                    graph_exec_id=node_exec.graph_exec_id,
                    node_exec_id=node_exec.node_exec_id,
                    graph_id=node_exec.graph_id,
                    node_id=node_exec.node_id,
                    block_id=node_exec.block_id,
                    block_name=block.name,
                    provider=provider_name,
                    credential_id=cred_id,
                    cost_microdollars=cost_microdollars,
                    input_tokens=stats.input_token_count,
                    output_tokens=stats.output_token_count,
                    data_size=stats.output_size if stats.output_size > 0 else None,
                    duration=stats.walltime if stats.walltime > 0 else None,
                    model=model_name,
                    tracking_type=tracking_type,
                    tracking_amount=tracking_amount,
                    metadata=meta,
                ),
            )
            return  # One log per execution is enough
    except Exception:
        logger.exception("log_system_credential_cost failed unexpectedly")
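`log_system_credential_cost` above converts the tracked amount with `usd_to_microdollars` only when `tracking_type == "cost_usd"`. The real helper lives in `backend.data.platform_cost`, so the body below is a hedged sketch for illustration only; it shows why rounding rather than truncation matters for this conversion:

```python
def usd_to_microdollars(usd: float) -> int:
    """Convert a USD amount to integer microdollars (1 USD = 1,000,000 microdollars)."""
    # round(), not int(): binary floats can land a hair off the exact product
    # (0.0015 * 1_000_000 is not exactly 1500.0 in float math), and int()
    # would truncate a value like 1499.999... down to 1499.
    return round(usd * 1_000_000)
```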
@@ -45,6 +45,10 @@ from backend.data.notifications import (
    ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.cost_tracking import (
    drain_pending_cost_logs,
    log_system_credential_cost,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.util import json
@@ -692,6 +696,15 @@ class ExecutionProcessor:
            stats=graph_stats,
        )

        # Log platform cost if system credentials were used (only on success)
        if status == ExecutionStatus.COMPLETED:
            await log_system_credential_cost(
                node_exec=node_exec,
                block=node.block,
                stats=execution_stats,
                db_client=db_client,
            )

        return execution_stats

    @async_time_measured
@@ -2044,14 +2057,23 @@ class ExecutionManager(AppProcess):
            prefix + " [cancel-consumer]",
        )

        # Drain any in-flight cost log tasks before exit so we don't silently
        # drop INSERT operations during deployments.
        loop = getattr(self, "node_execution_loop", None)
        if loop is not None and loop.is_running():
            try:
                asyncio.run_coroutine_threadsafe(
                    drain_pending_cost_logs(), loop
                ).result(timeout=10)
                logger.info(f"{prefix} ✅ Cost log tasks drained")
            except Exception as e:
                logger.warning(f"{prefix} ⚠️ Failed to drain cost log tasks: {e}")

        logger.info(f"{prefix} ✅ Finished GraphExec cleanup")

        super().cleanup()


# ------- UTILITIES ------- #


def get_db_client() -> "DatabaseManagerClient":
    return get_database_manager_client()
@@ -0,0 +1,567 @@
"""Unit tests for resolve_tracking and log_system_credential_cost."""

import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.data.execution import ExecutionContext, NodeExecutionEntry
from backend.data.model import NodeExecutionStats
from backend.executor.cost_tracking import log_system_credential_cost, resolve_tracking

# ---------------------------------------------------------------------------
# resolve_tracking
# ---------------------------------------------------------------------------


class TestResolveTracking:
    def _stats(self, **overrides: Any) -> NodeExecutionStats:
        return NodeExecutionStats(**overrides)

    def test_provider_cost_returns_cost_usd(self):
        stats = self._stats(provider_cost=0.0042)
        tt, amt = resolve_tracking("openai", stats, {})
        assert tt == "cost_usd"
        assert amt == 0.0042

    def test_token_counts_return_tokens(self):
        stats = self._stats(input_token_count=300, output_token_count=100)
        tt, amt = resolve_tracking("anthropic", stats, {})
        assert tt == "tokens"
        assert amt == 400.0

    def test_token_counts_only_input(self):
        stats = self._stats(input_token_count=500)
        tt, amt = resolve_tracking("groq", stats, {})
        assert tt == "tokens"
        assert amt == 500.0

    def test_unreal_speech_returns_characters(self):
        stats = self._stats()
        tt, amt = resolve_tracking("unreal_speech", stats, {"text": "Hello world"})
        assert tt == "characters"
        assert amt == 11.0

    def test_unreal_speech_empty_text(self):
        stats = self._stats()
        tt, amt = resolve_tracking("unreal_speech", stats, {"text": ""})
        assert tt == "characters"
        assert amt == 0.0

    def test_unreal_speech_non_string_text(self):
        stats = self._stats()
        tt, amt = resolve_tracking("unreal_speech", stats, {"text": 123})
        assert tt == "characters"
        assert amt == 0.0

    def test_d_id_uses_script_input(self):
        stats = self._stats()
        tt, amt = resolve_tracking("d_id", stats, {"script_input": "Hello"})
        assert tt == "characters"
        assert amt == 5.0

    def test_elevenlabs_uses_text(self):
        stats = self._stats()
        tt, amt = resolve_tracking("elevenlabs", stats, {"text": "Say this"})
        assert tt == "characters"
        assert amt == 8.0

    def test_elevenlabs_fallback_to_text_when_no_script_input(self):
        stats = self._stats()
        tt, amt = resolve_tracking("elevenlabs", stats, {"text": "Fallback text"})
        assert tt == "characters"
        assert amt == 13.0

    def test_elevenlabs_uses_script_field(self):
        """VideoNarrationBlock (elevenlabs) uses `script` field, not script_input/text."""
        stats = self._stats()
        tt, amt = resolve_tracking("elevenlabs", stats, {"script": "Narration"})
        assert tt == "characters"
        assert amt == 9.0

    def test_block_declared_cost_type_items(self):
        """Block explicitly setting provider_cost_type='items' short-circuits heuristics."""
        stats = self._stats(provider_cost=5.0, provider_cost_type="items")
        tt, amt = resolve_tracking("google_maps", stats, {})
        assert tt == "items"
        assert amt == 5.0

    def test_block_declared_cost_type_characters(self):
        """TTS block can declare characters directly, bypassing input_data lookup."""
        stats = self._stats(provider_cost=42.0, provider_cost_type="characters")
        tt, amt = resolve_tracking("unreal_speech", stats, {})
        assert tt == "characters"
        assert amt == 42.0

    def test_block_declared_cost_type_wins_over_tokens(self):
        """provider_cost_type takes precedence over token-based heuristic."""
        stats = self._stats(
            provider_cost=1.0,
            provider_cost_type="per_run",
            input_token_count=500,
        )
        tt, amt = resolve_tracking("openai", stats, {})
        assert tt == "per_run"
        assert amt == 1.0

    def test_e2b_returns_sandbox_seconds(self):
        stats = self._stats(walltime=45.123)
        tt, amt = resolve_tracking("e2b", stats, {})
        assert tt == "sandbox_seconds"
        assert amt == 45.123

    def test_e2b_no_walltime(self):
        stats = self._stats(walltime=0)
        tt, amt = resolve_tracking("e2b", stats, {})
        assert tt == "sandbox_seconds"
        assert amt == 0.0

    def test_fal_returns_walltime(self):
        stats = self._stats(walltime=12.5)
        tt, amt = resolve_tracking("fal", stats, {})
        assert tt == "walltime_seconds"
        assert amt == 12.5

    def test_revid_returns_walltime(self):
        stats = self._stats(walltime=60.0)
        tt, amt = resolve_tracking("revid", stats, {})
        assert tt == "walltime_seconds"
        assert amt == 60.0

    def test_replicate_returns_walltime(self):
        stats = self._stats(walltime=30.0)
        tt, amt = resolve_tracking("replicate", stats, {})
        assert tt == "walltime_seconds"
        assert amt == 30.0

    def test_unknown_provider_returns_per_run(self):
        stats = self._stats()
        tt, amt = resolve_tracking("google_maps", stats, {})
        assert tt == "per_run"
        assert amt == 1.0

    def test_provider_cost_takes_precedence_over_tokens(self):
        stats = self._stats(
            provider_cost=0.01, input_token_count=500, output_token_count=200
        )
        tt, amt = resolve_tracking("openai", stats, {})
        assert tt == "cost_usd"
        assert amt == 0.01

    def test_provider_cost_zero_is_not_none(self):
        """provider_cost=0.0 is falsy but should still be tracked as cost_usd
        (e.g. free-tier or fully-cached responses from OpenRouter)."""
        stats = self._stats(provider_cost=0.0)
        tt, amt = resolve_tracking("open_router", stats, {})
        assert tt == "cost_usd"
        assert amt == 0.0

    def test_tokens_take_precedence_over_provider_specific(self):
        stats = self._stats(input_token_count=100, walltime=10.0)
        tt, amt = resolve_tracking("fal", stats, {})
        assert tt == "tokens"
        assert amt == 100.0


# ---------------------------------------------------------------------------
# log_system_credential_cost
# ---------------------------------------------------------------------------


def _make_db_client() -> MagicMock:
    db_client = MagicMock()
    db_client.log_platform_cost = AsyncMock()
    return db_client


def _make_block(has_credentials: bool = True) -> MagicMock:
    block = MagicMock()
    block.name = "TestBlock"
    input_schema = MagicMock()
    if has_credentials:
        input_schema.get_credentials_fields.return_value = {"credentials": MagicMock()}
    else:
        input_schema.get_credentials_fields.return_value = {}
    block.input_schema = input_schema
    return block


def _make_node_exec(
    inputs: dict | None = None,
    dry_run: bool = False,
) -> NodeExecutionEntry:
    return NodeExecutionEntry(
        user_id="user-1",
        graph_exec_id="gx-1",
        graph_id="g-1",
        graph_version=1,
        node_exec_id="nx-1",
        node_id="n-1",
        block_id="b-1",
        inputs=inputs or {},
        execution_context=ExecutionContext(dry_run=dry_run),
    )


class TestLogSystemCredentialCost:
    @pytest.mark.asyncio
    async def test_skips_dry_run(self):
        db_client = _make_db_client()
        node_exec = _make_node_exec(dry_run=True)
        block = _make_block()
        stats = NodeExecutionStats()
        await log_system_credential_cost(node_exec, block, stats, db_client)
        db_client.log_platform_cost.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_skips_when_no_credential_fields(self):
        db_client = _make_db_client()
        node_exec = _make_node_exec(inputs={})
        block = _make_block(has_credentials=False)
        stats = NodeExecutionStats()
        await log_system_credential_cost(node_exec, block, stats, db_client)
        db_client.log_platform_cost.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_skips_when_cred_data_missing(self):
        db_client = _make_db_client()
        node_exec = _make_node_exec(inputs={})
        block = _make_block()
        stats = NodeExecutionStats()
        await log_system_credential_cost(node_exec, block, stats, db_client)
        db_client.log_platform_cost.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_skips_when_not_system_credential(self):
        db_client = _make_db_client()
        with patch(
            "backend.executor.cost_tracking.is_system_credential",
            return_value=False,
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "user-cred-123", "provider": "openai"},
                }
            )
            block = _make_block()
            stats = NodeExecutionStats()
            await log_system_credential_cost(node_exec, block, stats, db_client)
        db_client.log_platform_cost.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_logs_with_system_credential(self):
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(10, None),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred-1", "provider": "openai"},
                    "model": "gpt-4",
                }
            )
            block = _make_block()
            stats = NodeExecutionStats(input_token_count=500, output_token_count=200)
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

        db_client.log_platform_cost.assert_awaited_once()
        entry = db_client.log_platform_cost.call_args[0][0]
        assert entry.user_id == "user-1"
        assert entry.provider == "openai"
        assert entry.block_name == "TestBlock"
        assert entry.model == "gpt-4"
        assert entry.input_tokens == 500
        assert entry.output_tokens == 200
        assert entry.tracking_type == "tokens"
        assert entry.metadata["tracking_type"] == "tokens"
        assert entry.metadata["tracking_amount"] == 700.0
        assert entry.metadata["credit_cost"] == 10

    @pytest.mark.asyncio
    async def test_logs_with_provider_cost(self):
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(5, None),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred-2", "provider": "open_router"},
                }
            )
            block = _make_block()
            stats = NodeExecutionStats(provider_cost=0.0015)
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

        entry = db_client.log_platform_cost.call_args[0][0]
        assert entry.cost_microdollars == 1500
        assert entry.tracking_type == "cost_usd"
        assert entry.metadata["tracking_type"] == "cost_usd"
        assert entry.metadata["provider_cost_raw"] == 0.0015

    @pytest.mark.asyncio
    async def test_model_name_enum_converted_to_str(self):
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(0, None),
            ),
        ):
            from enum import Enum

            class FakeModel(Enum):
                GPT4 = "gpt-4"

            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred", "provider": "openai"},
                    "model": FakeModel.GPT4,
                }
            )
            block = _make_block()
            stats = NodeExecutionStats()
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

        entry = db_client.log_platform_cost.call_args[0][0]
        assert entry.model == "FakeModel.GPT4"

    @pytest.mark.asyncio
    async def test_model_name_dict_becomes_none(self):
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(0, None),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred", "provider": "openai"},
                    "model": {"nested": "value"},
                }
            )
            block = _make_block()
            stats = NodeExecutionStats()
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

        entry = db_client.log_platform_cost.call_args[0][0]
        assert entry.model is None

    @pytest.mark.asyncio
    async def test_does_not_raise_when_block_usage_cost_raises(self):
        """log_system_credential_cost must swallow exceptions from block_usage_cost."""
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                side_effect=RuntimeError("pricing lookup failed"),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred", "provider": "openai"},
                }
            )
            block = _make_block()
            stats = NodeExecutionStats()
            # Should not raise — outer except must catch block_usage_cost error
            await log_system_credential_cost(node_exec, block, stats, db_client)

    @pytest.mark.asyncio
    async def test_round_instead_of_int_for_microdollars(self):
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(0, None),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred", "provider": "openai"},
                }
            )
            block = _make_block()
            # 0.0015 * 1_000_000 = 1499.9999999... with float math
            # round() should give 1500, int() would give 1499
            stats = NodeExecutionStats(provider_cost=0.0015)
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

        entry = db_client.log_platform_cost.call_args[0][0]
        assert entry.cost_microdollars == 1500

    @pytest.mark.asyncio
    async def test_per_run_metadata_has_no_provider_cost_raw(self):
        """For per-run providers (google_maps etc), provider_cost_raw is absent
        from metadata since stats.provider_cost is None."""
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(0, None),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred", "provider": "google_maps"},
                }
            )
            block = _make_block()
            stats = NodeExecutionStats()  # no provider_cost
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

        entry = db_client.log_platform_cost.call_args[0][0]
        assert entry.tracking_type == "per_run"
        assert "provider_cost_raw" not in (entry.metadata or {})


# ---------------------------------------------------------------------------
# merge_stats accumulation
# ---------------------------------------------------------------------------


class TestMergeStats:
    """Tests for NodeExecutionStats accumulation via += (used by Block.merge_stats)."""

    def test_accumulates_output_size(self):
        stats = NodeExecutionStats()
        stats += NodeExecutionStats(output_size=10)
        stats += NodeExecutionStats(output_size=25)
        assert stats.output_size == 35

    def test_accumulates_tokens(self):
        stats = NodeExecutionStats()
|
||||
stats += NodeExecutionStats(input_token_count=100, output_token_count=50)
|
||||
stats += NodeExecutionStats(input_token_count=200, output_token_count=150)
|
||||
assert stats.input_token_count == 300
|
||||
assert stats.output_token_count == 200
|
||||
|
||||
def test_preserves_provider_cost(self):
|
||||
stats = NodeExecutionStats()
|
||||
stats += NodeExecutionStats(provider_cost=0.005)
|
||||
stats += NodeExecutionStats(output_size=10)
|
||||
assert stats.provider_cost == 0.005
|
||||
assert stats.output_size == 10
|
||||
|
||||
def test_provider_cost_accumulates(self):
|
||||
"""Multiple merge_stats with provider_cost should sum (multi-round
|
||||
tool-calling in copilot / retries can report cost separately)."""
|
||||
stats = NodeExecutionStats()
|
||||
stats += NodeExecutionStats(provider_cost=0.001)
|
||||
stats += NodeExecutionStats(provider_cost=0.002)
|
||||
stats += NodeExecutionStats(provider_cost=0.003)
|
||||
assert stats.provider_cost == pytest.approx(0.006)
|
||||
|
||||
def test_provider_cost_none_does_not_overwrite(self):
|
||||
"""A None provider_cost must not wipe a previously-set value."""
|
||||
stats = NodeExecutionStats(provider_cost=0.01)
|
||||
stats += NodeExecutionStats() # provider_cost=None by default
|
||||
assert stats.provider_cost == 0.01
|
||||
|
||||
def test_provider_cost_type_last_write_wins(self):
|
||||
"""provider_cost_type is a Literal — last set value wins on merge."""
|
||||
stats = NodeExecutionStats(provider_cost_type="tokens")
|
||||
stats += NodeExecutionStats(provider_cost_type="items")
|
||||
assert stats.provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# on_node_execution -> log_system_credential_cost integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestManagerCostTrackingIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_called_with_accumulated_stats(self):
|
||||
"""Verify that log_system_credential_cost receives stats that could
|
||||
have been accumulated by merge_stats across multiple yield steps."""
|
||||
db_client = _make_db_client()
|
||||
with (
|
||||
patch(
|
||||
"backend.executor.cost_tracking.is_system_credential", return_value=True
|
||||
),
|
||||
patch(
|
||||
"backend.executor.cost_tracking.block_usage_cost",
|
||||
return_value=(5, None),
|
||||
),
|
||||
):
|
||||
stats = NodeExecutionStats()
|
||||
stats += NodeExecutionStats(output_size=10, input_token_count=100)
|
||||
stats += NodeExecutionStats(output_size=25, input_token_count=200)
|
||||
|
||||
assert stats.output_size == 35
|
||||
assert stats.input_token_count == 300
|
||||
|
||||
node_exec = _make_node_exec(
|
||||
inputs={
|
||||
"credentials": {"id": "sys-cred-acc", "provider": "openai"},
|
||||
"model": "gpt-4",
|
||||
}
|
||||
)
|
||||
block = _make_block()
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
db_client.log_platform_cost.assert_awaited_once()
|
||||
entry = db_client.log_platform_cost.call_args[0][0]
|
||||
assert entry.input_tokens == 300
|
||||
assert entry.tracking_type == "tokens"
|
||||
assert entry.metadata["tracking_amount"] == 300.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_cost_log_when_status_is_failed(self):
|
||||
"""Manager only calls log_system_credential_cost on COMPLETED status.
|
||||
|
||||
This test verifies the guard condition `if status == COMPLETED` directly:
|
||||
calling log_system_credential_cost only happens on success, never on
|
||||
FAILED or ERROR executions.
|
||||
"""
|
||||
from backend.data.execution import ExecutionStatus
|
||||
|
||||
db_client = _make_db_client()
|
||||
node_exec = _make_node_exec(
|
||||
inputs={"credentials": {"id": "sys-cred", "provider": "openai"}}
|
||||
)
|
||||
block = _make_block()
|
||||
stats = NodeExecutionStats(input_token_count=100)
|
||||
|
||||
# Simulate the manager guard: only call on COMPLETED
|
||||
status = ExecutionStatus.FAILED
|
||||
if status == ExecutionStatus.COMPLETED:
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
|
||||
db_client.log_platform_cost.assert_not_awaited()
|
||||
@@ -0,0 +1,42 @@
-- CreateTable
CREATE TABLE "PlatformCostLog" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "userId" TEXT,
    "graphExecId" TEXT,
    "nodeExecId" TEXT,
    "graphId" TEXT,
    "nodeId" TEXT,
    "blockId" TEXT NOT NULL,
    "blockName" TEXT NOT NULL,
    "provider" TEXT NOT NULL,
    "credentialId" TEXT NOT NULL,
    "costMicrodollars" BIGINT,
    "inputTokens" INTEGER,
    "outputTokens" INTEGER,
    "dataSize" INTEGER,
    "duration" DOUBLE PRECISION,
    "model" TEXT,
    "trackingType" TEXT,
    "metadata" JSONB,

    CONSTRAINT "PlatformCostLog_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE INDEX "PlatformCostLog_userId_createdAt_idx" ON "PlatformCostLog"("userId", "createdAt");

-- CreateIndex
CREATE INDEX "PlatformCostLog_provider_createdAt_idx" ON "PlatformCostLog"("provider", "createdAt");

-- CreateIndex
CREATE INDEX "PlatformCostLog_createdAt_idx" ON "PlatformCostLog"("createdAt");

-- CreateIndex
CREATE INDEX "PlatformCostLog_graphExecId_idx" ON "PlatformCostLog"("graphExecId");

-- CreateIndex
CREATE INDEX "PlatformCostLog_provider_trackingType_idx" ON "PlatformCostLog"("provider", "trackingType");

-- AddForeignKey
ALTER TABLE "PlatformCostLog" ADD CONSTRAINT "PlatformCostLog_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "PlatformCostLog" ADD COLUMN "trackingAmount" DOUBLE PRECISION;
@@ -75,6 +75,8 @@ model User {
  PendingHumanReviews PendingHumanReview[]
  Workspace           UserWorkspace?

  PlatformCostLogs PlatformCostLog[]

  // OAuth Provider relations
  OAuthApplications       OAuthApplication[]
  OAuthAuthorizationCodes OAuthAuthorizationCode[]
@@ -815,6 +817,45 @@ model CreditRefundRequest {
  @@index([userId, transactionKey])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////// Platform Cost Tracking TABLES //////////////
////////////////////////////////////////////////////////////

model PlatformCostLog {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())

  userId String?
  User   User?   @relation(fields: [userId], references: [id], onDelete: SetNull)
  graphExecId  String?
  nodeExecId   String?
  graphId      String?
  nodeId       String?
  blockId      String
  blockName    String
  provider     String
  credentialId String

  // Cost in microdollars (1 USD = 1,000,000). Null if unknown.
  costMicrodollars BigInt?

  inputTokens    Int?
  outputTokens   Int?
  dataSize       Int? // bytes
  duration       Float? // seconds
  model          String?
  trackingType   String? // e.g. "cost_usd", "tokens", "characters", "items", "per_run", "sandbox_seconds", "walltime_seconds"
  trackingAmount Float? // Amount in the unit implied by trackingType
  metadata       Json?

  @@index([userId, createdAt])
  @@index([provider, createdAt])
  @@index([createdAt])
  @@index([graphExecId])
  @@index([provider, trackingType])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// Store TABLES ///////////////////////////

@@ -1,6 +1,12 @@
import { Sidebar } from "@/components/__legacy__/Sidebar";
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
import { Gauge } from "@phosphor-icons/react/dist/ssr";
import {
  Users,
  CurrencyDollar,
  MagnifyingGlass,
  Gauge,
  Receipt,
  FileText,
} from "@phosphor-icons/react/dist/ssr";

import { IconSliders } from "@/components/__legacy__/ui/icons";

@@ -15,18 +21,23 @@ const sidebarLinkGroups = [
    {
      text: "User Spending",
      href: "/admin/spending",
      icon: <DollarSign className="h-6 w-6" />,
      icon: <CurrencyDollar className="h-6 w-6" />,
    },
    {
      text: "User Impersonation",
      href: "/admin/impersonation",
      icon: <UserSearch className="h-6 w-6" />,
      icon: <MagnifyingGlass className="h-6 w-6" />,
    },
    {
      text: "Rate Limits",
      href: "/admin/rate-limits",
      icon: <Gauge className="h-6 w-6" />,
    },
    {
      text: "Platform Costs",
      href: "/admin/platform-costs",
      icon: <Receipt className="h-6 w-6" />,
    },
    {
      text: "Execution Analytics",
      href: "/admin/execution-analytics",

@@ -0,0 +1,429 @@
import {
  render,
  screen,
  cleanup,
  waitFor,
} from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { PlatformCostContent } from "../components/PlatformCostContent";
import type { PlatformCostDashboard } from "@/app/api/__generated__/models/platformCostDashboard";
import type { PlatformCostLogsResponse } from "@/app/api/__generated__/models/platformCostLogsResponse";

// Mock the generated Orval hooks so tests don't hit the network
const mockUseGetDashboard = vi.fn();
const mockUseGetLogs = vi.fn();

vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({
  useGetV2GetPlatformCostDashboard: (...args: unknown[]) =>
    mockUseGetDashboard(...args),
  useGetV2GetPlatformCostLogs: (...args: unknown[]) => mockUseGetLogs(...args),
}));

afterEach(() => {
  cleanup();
  mockUseGetDashboard.mockReset();
  mockUseGetLogs.mockReset();
});

const emptyDashboard: PlatformCostDashboard = {
  total_cost_microdollars: 0,
  total_requests: 0,
  total_users: 0,
  by_provider: [],
  by_user: [],
};

const emptyLogs: PlatformCostLogsResponse = {
  logs: [],
  pagination: {
    current_page: 1,
    page_size: 50,
    total_items: 0,
    total_pages: 0,
  },
};

const dashboardWithData: PlatformCostDashboard = {
  total_cost_microdollars: 5_000_000,
  total_requests: 100,
  total_users: 5,
  by_provider: [
    {
      provider: "openai",
      tracking_type: "tokens",
      total_cost_microdollars: 3_000_000,
      total_input_tokens: 50000,
      total_output_tokens: 20000,
      total_duration_seconds: 0,
      request_count: 60,
    },
    {
      provider: "google_maps",
      tracking_type: "per_run",
      total_cost_microdollars: 0,
      total_input_tokens: 0,
      total_output_tokens: 0,
      total_duration_seconds: 0,
      request_count: 40,
    },
  ],
  by_user: [
    {
      user_id: "user-1",
      email: "alice@example.com",
      total_cost_microdollars: 3_000_000,
      total_input_tokens: 50000,
      total_output_tokens: 20000,
      request_count: 60,
    },
  ],
};

const logsWithData: PlatformCostLogsResponse = {
  logs: [
    {
      id: "log-1",
      created_at: "2026-03-01T00:00:00Z" as unknown as Date,
      user_id: "user-1",
      email: "alice@example.com",
      graph_exec_id: "gx-123",
      node_exec_id: "nx-456",
      block_name: "LLMBlock",
      provider: "openai",
      tracking_type: "tokens",
      cost_microdollars: 5000,
      input_tokens: 100,
      output_tokens: 50,
      duration: 1.5,
      model: "gpt-4",
    },
  ],
  pagination: {
    current_page: 1,
    page_size: 50,
    total_items: 1,
    total_pages: 1,
  },
};

function renderComponent(searchParams = {}) {
  return render(<PlatformCostContent searchParams={searchParams} />);
}

describe("PlatformCostContent", () => {
  it("shows loading state initially", () => {
    mockUseGetDashboard.mockReturnValue({ data: undefined, isLoading: true });
    mockUseGetLogs.mockReturnValue({ data: undefined, isLoading: true });
    renderComponent();
    // Loading state renders Skeleton placeholders (animate-pulse divs) instead of content
    expect(screen.queryByText("Loading...")).toBeNull();
    // Summary cards and table content are not yet shown
    expect(screen.queryByText("Known Cost")).toBeNull();
  });

  it("renders empty dashboard", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: emptyDashboard,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: emptyLogs,
      isLoading: false,
    });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    // Verify the two summary cards that show $0.0000 — Known Cost and Estimated Total
    const zeroCostItems = screen.getAllByText("$0.0000");
    expect(zeroCostItems.length).toBe(2);
    expect(screen.getByText("No cost data yet")).toBeDefined();
  });

  it("renders dashboard with provider data", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("$5.0000")).toBeDefined();
    expect(screen.getByText("100")).toBeDefined();
    expect(screen.getByText("5")).toBeDefined();
    expect(screen.getByText("openai")).toBeDefined();
    expect(screen.getByText("google_maps")).toBeDefined();
  });

  it("renders tracking type badges", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("tokens")).toBeDefined();
    expect(screen.getByText("per_run")).toBeDefined();
  });

  it("shows error state on fetch failure", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: undefined,
      isLoading: false,
      error: new Error("Network error"),
    });
    mockUseGetLogs.mockReturnValue({
      data: undefined,
      isLoading: false,
      error: new Error("Network error"),
    });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("Network error")).toBeDefined();
  });

  it("renders tab buttons", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: emptyDashboard,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("By Provider")).toBeDefined();
    expect(screen.getByText("By User")).toBeDefined();
    expect(screen.getByText("Raw Logs")).toBeDefined();
  });

  it("renders summary cards with correct labels", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
    expect(screen.getByText("Estimated Total")).toBeDefined();
    expect(screen.getByText("Total Requests")).toBeDefined();
    expect(screen.getByText("Active Users")).toBeDefined();
  });

  it("renders filter inputs", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: emptyDashboard,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("Start Date")).toBeDefined();
    expect(screen.getByText("End Date")).toBeDefined();
    expect(screen.getAllByText(/Provider/i).length).toBeGreaterThanOrEqual(1);
    expect(screen.getByText("User ID")).toBeDefined();
    expect(screen.getByText("Apply")).toBeDefined();
  });

  it("renders by-user tab when specified", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent({ tab: "by-user" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("alice@example.com")).toBeDefined();
  });

  it("renders logs tab when specified", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent({ tab: "logs" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("LLMBlock")).toBeDefined();
    expect(screen.getByText("gpt-4")).toBeDefined();
  });

  it("renders no logs message when empty", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: emptyDashboard,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
    renderComponent({ tab: "logs" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("No logs found")).toBeDefined();
  });

  it("shows pagination when multiple pages", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    const multiPageLogs: PlatformCostLogsResponse = {
      logs: logsWithData.logs,
      pagination: {
        current_page: 1,
        page_size: 50,
        total_items: 200,
        total_pages: 4,
      },
    };
    mockUseGetLogs.mockReturnValue({
      data: multiPageLogs,
      isLoading: false,
    });
    renderComponent({ tab: "logs" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("Previous")).toBeDefined();
    expect(screen.getByText("Next")).toBeDefined();
    expect(screen.getByText(/Page 1 of 4/)).toBeDefined();
  });

  it("renders user table with unknown email", async () => {
    const dashWithNullEmail: PlatformCostDashboard = {
      ...dashboardWithData,
      by_user: [
        {
          user_id: "user-2",
          email: null,
          total_cost_microdollars: 1000,
          total_input_tokens: 100,
          total_output_tokens: 50,
          request_count: 5,
        },
      ],
    };
    mockUseGetDashboard.mockReturnValue({
      data: dashWithNullEmail,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
    renderComponent({ tab: "by-user" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("Unknown")).toBeDefined();
  });

  it("by-user tab content visible when tab=by-user param set", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent({ tab: "by-user" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("alice@example.com")).toBeDefined();
    // overview tab content should not be visible
    expect(screen.queryByText("openai")).toBeNull();
  });

  it("logs tab content visible when tab=logs param set", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent({ tab: "logs" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("LLMBlock")).toBeDefined();
    expect(screen.getByText("gpt-4")).toBeDefined();
  });

  it("renders log with null user as dash", async () => {
    const logWithNullUser: PlatformCostLogsResponse = {
      logs: [
        {
          id: "log-2",
          created_at: "2026-03-01T00:00:00Z" as unknown as Date,
          user_id: null,
          email: null,
          graph_exec_id: null,
          node_exec_id: null,
          block_name: "copilot:SDK",
          provider: "anthropic",
          tracking_type: "cost_usd",
          cost_microdollars: 15000,
          input_tokens: null,
          output_tokens: null,
          duration: null,
          model: "claude-opus-4-20250514",
        },
      ],
      pagination: {
        current_page: 1,
        page_size: 50,
        total_items: 1,
        total_pages: 1,
      },
    };
    mockUseGetDashboard.mockReturnValue({
      data: emptyDashboard,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logWithNullUser,
      isLoading: false,
    });
    renderComponent({ tab: "logs" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("copilot:SDK")).toBeDefined();
    expect(screen.getByText("anthropic")).toBeDefined();
    // null email + null user_id renders as "-" in the User column; multiple
    // other cells (tokens, duration, session) also render "-", so use
    // getAllByText to avoid the single-match constraint.
    expect(screen.getAllByText("-").length).toBeGreaterThan(0);
  });
});
@@ -0,0 +1,87 @@
import { describe, expect, it, vi } from "vitest";

const mockGetDashboard = vi.fn();
const mockGetLogs = vi.fn();

vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({
  getV2GetPlatformCostDashboard: (...args: unknown[]) =>
    mockGetDashboard(...args),
  getV2GetPlatformCostLogs: (...args: unknown[]) => mockGetLogs(...args),
}));

import { getPlatformCostDashboard, getPlatformCostLogs } from "../actions";

describe("getPlatformCostDashboard", () => {
  it("returns data on success", async () => {
    const mockData = { total_cost_microdollars: 1000, total_requests: 5 };
    mockGetDashboard.mockResolvedValue({ status: 200, data: mockData });
    const result = await getPlatformCostDashboard();
    expect(result).toEqual(mockData);
  });

  it("returns undefined on non-200", async () => {
    mockGetDashboard.mockResolvedValue({ status: 401 });
    const result = await getPlatformCostDashboard();
    expect(result).toBeUndefined();
  });

  it("passes filter params to API", async () => {
    mockGetDashboard.mockReset();
    mockGetDashboard.mockResolvedValue({ status: 200, data: {} });
    await getPlatformCostDashboard({
      start: "2026-01-01T00:00:00",
      end: "2026-06-01T00:00:00",
      provider: "openai",
      user_id: "user-1",
    });
    expect(mockGetDashboard).toHaveBeenCalledTimes(1);
    const params = mockGetDashboard.mock.calls[0][0];
    expect(params.start).toBe("2026-01-01T00:00:00");
    expect(params.end).toBe("2026-06-01T00:00:00");
    expect(params.provider).toBe("openai");
    expect(params.user_id).toBe("user-1");
  });

  it("passes undefined for empty filter strings", async () => {
    mockGetDashboard.mockReset();
    mockGetDashboard.mockResolvedValue({ status: 200, data: {} });
    await getPlatformCostDashboard({
      start: "",
      provider: "",
      user_id: "",
    });
    expect(mockGetDashboard).toHaveBeenCalledTimes(1);
    const params = mockGetDashboard.mock.calls[0][0];
    expect(params.start).toBeUndefined();
    expect(params.provider).toBeUndefined();
    expect(params.user_id).toBeUndefined();
  });
});

describe("getPlatformCostLogs", () => {
  it("returns data on success", async () => {
    const mockData = { logs: [], pagination: { current_page: 1 } };
    mockGetLogs.mockResolvedValue({ status: 200, data: mockData });
    const result = await getPlatformCostLogs();
    expect(result).toEqual(mockData);
  });

  it("passes page and page_size", async () => {
    mockGetLogs.mockReset();
    mockGetLogs.mockResolvedValue({ status: 200, data: { logs: [] } });
    await getPlatformCostLogs({ page: 3, page_size: 25 });
    expect(mockGetLogs).toHaveBeenCalledTimes(1);
    const params = mockGetLogs.mock.calls[0][0];
    expect(params.page).toBe(3);
    expect(params.page_size).toBe(25);
  });

  it("passes start date string through to API", async () => {
    mockGetLogs.mockReset();
    mockGetLogs.mockResolvedValue({ status: 200, data: { logs: [] } });
    await getPlatformCostLogs({ start: "2026-03-01T00:00:00" });
    expect(mockGetLogs).toHaveBeenCalledTimes(1);
    const params = mockGetLogs.mock.calls[0][0];
    expect(params.start).toBe("2026-03-01T00:00:00");
  });
});
@@ -0,0 +1,300 @@
import { describe, expect, it } from "vitest";
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
import {
  toDateOrUndefined,
  formatMicrodollars,
  formatTokens,
  formatDuration,
  estimateCostForRow,
  trackingValue,
  toLocalInput,
  toUtcIso,
} from "../helpers";

function makeRow(overrides: Partial<ProviderCostSummary>): ProviderCostSummary {
  return {
    provider: "openai",
    tracking_type: null,
    total_cost_microdollars: 0,
    total_input_tokens: 0,
    total_output_tokens: 0,
    total_duration_seconds: 0,
    request_count: 0,
    ...overrides,
  };
}

describe("toDateOrUndefined", () => {
  it("returns undefined for empty string", () => {
    expect(toDateOrUndefined("")).toBeUndefined();
  });

  it("returns undefined for undefined", () => {
    expect(toDateOrUndefined(undefined)).toBeUndefined();
  });

  it("returns undefined for invalid date string", () => {
    expect(toDateOrUndefined("not-a-date")).toBeUndefined();
  });

  it("returns a Date for a valid ISO string", () => {
    const result = toDateOrUndefined("2026-01-15T00:00:00Z");
    expect(result).toBeInstanceOf(Date);
    expect(result!.toISOString()).toBe("2026-01-15T00:00:00.000Z");
  });
});

describe("formatMicrodollars", () => {
  it("formats zero", () => {
    expect(formatMicrodollars(0)).toBe("$0.0000");
  });

  it("formats a small amount", () => {
    expect(formatMicrodollars(50_000)).toBe("$0.0500");
  });

  it("formats one dollar", () => {
    expect(formatMicrodollars(1_000_000)).toBe("$1.0000");
  });
});

describe("formatTokens", () => {
  it("formats small numbers as-is", () => {
    expect(formatTokens(500)).toBe("500");
  });

  it("formats thousands with K suffix", () => {
    expect(formatTokens(1_500)).toBe("1.5K");
  });

  it("formats millions with M suffix", () => {
    expect(formatTokens(2_500_000)).toBe("2.5M");
  });
});

describe("formatDuration", () => {
  it("formats seconds", () => {
    expect(formatDuration(30)).toBe("30.0s");
  });

  it("formats minutes", () => {
    expect(formatDuration(90)).toBe("1.5m");
  });

  it("formats hours", () => {
    expect(formatDuration(5400)).toBe("1.5h");
  });
});

describe("estimateCostForRow", () => {
  it("returns microdollars directly for cost_usd tracking", () => {
    const row = makeRow({
      tracking_type: "cost_usd",
      total_cost_microdollars: 500_000,
    });
    expect(estimateCostForRow(row, {})).toBe(500_000);
  });

  it("returns reported cost for token tracking when cost > 0", () => {
    const row = makeRow({
      tracking_type: "tokens",
      total_cost_microdollars: 100_000,
      total_input_tokens: 1000,
      total_output_tokens: 500,
    });
    expect(estimateCostForRow(row, {})).toBe(100_000);
  });

  it("estimates cost from default rate for token tracking with zero cost", () => {
    const row = makeRow({
      provider: "openai",
      tracking_type: "tokens",
      total_cost_microdollars: 0,
      total_input_tokens: 500,
      total_output_tokens: 500,
    });
    // 1000 tokens / 1000 * 0.005 USD * 1_000_000 = 5000
    expect(estimateCostForRow(row, {})).toBe(5000);
  });

  it("returns null for unknown token provider with zero cost", () => {
    const row = makeRow({
      provider: "unknown_provider",
      tracking_type: "tokens",
      total_cost_microdollars: 0,
    });
    expect(estimateCostForRow(row, {})).toBeNull();
  });

  it("uses per-run override when provided", () => {
    const row = makeRow({
      provider: "google_maps",
      tracking_type: "per_run",
      request_count: 10,
    });
    // override = 0.05 * 10 * 1_000_000 = 500_000
    expect(estimateCostForRow(row, { "google_maps:per_run": 0.05 })).toBe(
      500_000,
    );
  });

  it("uses default per-run cost when no override", () => {
    const row = makeRow({
      provider: "google_maps",
      tracking_type: null,
      request_count: 5,
    });
    // 0.032 * 5 * 1_000_000 = 160_000
    expect(estimateCostForRow(row, {})).toBe(160_000);
  });

  it("returns null for unknown per_run provider", () => {
    const row = makeRow({
      provider: "totally_unknown",
      tracking_type: "per_run",
      request_count: 3,
    });
    expect(estimateCostForRow(row, {})).toBeNull();
  });
|
||||
it("returns null for duration tracking with no rate and no cost", () => {
|
||||
const row = makeRow({
|
||||
provider: "openai",
|
||||
tracking_type: "duration_seconds",
|
||||
total_cost_microdollars: 0,
|
||||
total_duration_seconds: 100,
|
||||
});
|
||||
expect(estimateCostForRow(row, {})).toBeNull();
|
||||
});
|
||||
|
||||
it("estimates cost from default rate for characters tracking", () => {
|
||||
const row = makeRow({
|
||||
provider: "elevenlabs",
|
||||
tracking_type: "characters",
|
||||
total_cost_microdollars: 0,
|
||||
total_tracking_amount: 2000,
|
||||
});
|
||||
// 2000 chars / 1000 * 0.18 USD * 1_000_000 = 360_000
|
||||
expect(estimateCostForRow(row, {})).toBe(360_000);
|
||||
});
|
||||
|
||||
it("estimates cost from default rate for items tracking", () => {
|
||||
const row = makeRow({
|
||||
provider: "apollo",
|
||||
tracking_type: "items",
|
||||
total_cost_microdollars: 0,
|
||||
total_tracking_amount: 50,
|
||||
});
|
||||
// 50 * 0.02 * 1_000_000 = 1_000_000
|
||||
expect(estimateCostForRow(row, {})).toBe(1_000_000);
|
||||
});
|
||||
|
||||
it("estimates cost from default rate for duration tracking", () => {
|
||||
const row = makeRow({
|
||||
provider: "e2b",
|
||||
tracking_type: "sandbox_seconds",
|
||||
total_cost_microdollars: 0,
|
||||
total_duration_seconds: 1_000_000,
|
||||
});
|
||||
// 1_000_000 * 0.000014 * 1_000_000 = 14_000_000
|
||||
expect(estimateCostForRow(row, {})).toBe(14_000_000);
|
||||
});
|
||||
});
|
||||
|
||||
describe("trackingValue", () => {
  it("returns formatted microdollars for cost_usd", () => {
    const row = makeRow({
      tracking_type: "cost_usd",
      total_cost_microdollars: 1_000_000,
    });
    expect(trackingValue(row)).toBe("$1.0000");
  });

  it("returns formatted token count for tokens", () => {
    const row = makeRow({
      tracking_type: "tokens",
      total_input_tokens: 500,
      total_output_tokens: 500,
    });
    expect(trackingValue(row)).toBe("1.0K tokens");
  });

  it("returns formatted duration in minutes for sandbox_seconds", () => {
    const row = makeRow({
      tracking_type: "sandbox_seconds",
      total_duration_seconds: 120,
    });
    expect(trackingValue(row)).toBe("2.0m");
  });

  it("returns run count for per_run (default tracking)", () => {
    const row = makeRow({
      tracking_type: null,
      request_count: 42,
    });
    expect(trackingValue(row)).toBe("42 runs");
  });

  it("returns formatted character count for characters tracking", () => {
    const row = makeRow({
      tracking_type: "characters",
      total_tracking_amount: 2500,
    });
    expect(trackingValue(row)).toBe("2.5K chars");
  });

  it("returns formatted item count for items tracking", () => {
    const row = makeRow({
      tracking_type: "items",
      total_tracking_amount: 1234,
    });
    expect(trackingValue(row)).toBe("1,234 items");
  });

  it("returns formatted duration in hours for sandbox_seconds", () => {
    const row = makeRow({
      tracking_type: "sandbox_seconds",
      total_duration_seconds: 7200,
    });
    expect(trackingValue(row)).toBe("2.0h");
  });

  it("returns formatted duration for walltime_seconds", () => {
    const row = makeRow({
      tracking_type: "walltime_seconds",
      total_duration_seconds: 45,
    });
    expect(trackingValue(row)).toBe("45.0s");
  });
});

describe("toLocalInput", () => {
  it("returns empty string for empty input", () => {
    expect(toLocalInput("")).toBe("");
  });

  it("returns empty string for invalid ISO", () => {
    expect(toLocalInput("not-a-date")).toBe("");
  });

  it("converts UTC ISO to local datetime-local format", () => {
    const result = toLocalInput("2026-01-15T12:30:00Z");
    // Format should be YYYY-MM-DDTHH:mm
    expect(result).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$/);
  });
});

describe("toUtcIso", () => {
  it("returns empty string for empty input", () => {
    expect(toUtcIso("")).toBe("");
  });

  it("returns empty string for invalid local time", () => {
    expect(toUtcIso("not-a-date")).toBe("");
  });

  it("converts local datetime-local to ISO string", () => {
    const result = toUtcIso("2026-01-15T12:30");
    expect(result).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/);
  });
});
@@ -0,0 +1,45 @@
import {
  getV2GetPlatformCostDashboard,
  getV2GetPlatformCostLogs,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { okData } from "@/app/api/helpers";

// Backend expects ISO datetime strings. The generated client's URL builder
// calls .toString() on values, which for Date objects produces the human
// "Tue Mar 31 2026 22:00:00 GMT+0000 (Coordinated Universal Time)" format
// that FastAPI rejects with 422. We already pass UTC ISO from the URL, so
// forward the raw strings through the `as unknown as Date` cast to match
// the generated typing without triggering Date.toString().
export async function getPlatformCostDashboard(params?: {
  start?: string;
  end?: string;
  provider?: string;
  user_id?: string;
}) {
  const response = await getV2GetPlatformCostDashboard({
    start: (params?.start || undefined) as unknown as Date | undefined,
    end: (params?.end || undefined) as unknown as Date | undefined,
    provider: params?.provider || undefined,
    user_id: params?.user_id || undefined,
  });
  return okData(response);
}

export async function getPlatformCostLogs(params?: {
  start?: string;
  end?: string;
  provider?: string;
  user_id?: string;
  page?: number;
  page_size?: number;
}) {
  const response = await getV2GetPlatformCostLogs({
    start: (params?.start || undefined) as unknown as Date | undefined,
    end: (params?.end || undefined) as unknown as Date | undefined,
    provider: params?.provider || undefined,
    user_id: params?.user_id || undefined,
    page: params?.page,
    page_size: params?.page_size,
  });
  return okData(response);
}
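The cast above hinges on how JavaScript stringifies dates. A minimal standalone sketch of the difference (not using the generated client; `serializesAsIso` is an illustrative helper, not part of this codebase):

```typescript
// A Date stringifies to a human-readable form that FastAPI's datetime
// parser rejects, while an ISO-8601 string passes through .toString()
// unchanged -- which is why the wrappers forward raw ISO strings.
const asDate = new Date("2026-01-15T12:30:00Z");
const asIso = "2026-01-15T12:30:00.000Z";

const isoPattern = /^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/;

function serializesAsIso(value: { toString(): string }): boolean {
  return isoPattern.test(value.toString());
}

console.log(serializesAsIso(asDate)); // false: "Thu Jan 15 2026 ..." style
console.log(serializesAsIso(asIso)); // true: string survives unchanged
```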
@@ -0,0 +1,140 @@
import type { CostLogRow } from "@/app/api/__generated__/models/costLogRow";
import type { Pagination } from "@/app/api/__generated__/models/pagination";
import { formatDuration, formatMicrodollars, formatTokens } from "../helpers";
import { TrackingBadge } from "./TrackingBadge";

function formatLogDate(value: unknown): string {
  if (value instanceof Date) return value.toLocaleString();
  if (typeof value === "string" || typeof value === "number")
    return new Date(value).toLocaleString();
  return "-";
}

interface Props {
  logs: CostLogRow[];
  pagination: Pagination | null;
  onPageChange: (page: number) => void;
}

function LogsTable({ logs, pagination, onPageChange }: Props) {
  return (
    <div className="flex flex-col gap-4">
      <div className="overflow-x-auto">
        <table className="w-full text-left text-sm">
          <thead className="border-b text-xs uppercase text-muted-foreground">
            <tr>
              <th scope="col" className="px-3 py-3">
                Time
              </th>
              <th scope="col" className="px-3 py-3">
                User
              </th>
              <th scope="col" className="px-3 py-3">
                Block
              </th>
              <th scope="col" className="px-3 py-3">
                Provider
              </th>
              <th scope="col" className="px-3 py-3">
                Type
              </th>
              <th scope="col" className="px-3 py-3">
                Model
              </th>
              <th scope="col" className="px-3 py-3 text-right">
                Cost
              </th>
              <th scope="col" className="px-3 py-3 text-right">
                Tokens
              </th>
              <th scope="col" className="px-3 py-3 text-right">
                Duration
              </th>
              <th scope="col" className="px-3 py-3">
                Session
              </th>
            </tr>
          </thead>
          <tbody>
            {logs.map((log) => (
              <tr key={log.id} className="border-b hover:bg-muted">
                <td className="whitespace-nowrap px-3 py-2 text-xs">
                  {formatLogDate(log.created_at)}
                </td>
                <td className="px-3 py-2 text-xs">
                  {log.email ||
                    (log.user_id ? String(log.user_id).slice(0, 8) : "-")}
                </td>
                <td className="px-3 py-2 text-xs font-medium">
                  {log.block_name}
                </td>
                <td className="px-3 py-2 text-xs">{log.provider}</td>
                <td className="px-3 py-2 text-xs">
                  <TrackingBadge trackingType={log.tracking_type} />
                </td>
                <td className="px-3 py-2 text-xs">{log.model || "-"}</td>
                <td className="px-3 py-2 text-right text-xs">
                  {log.cost_microdollars != null
                    ? formatMicrodollars(Number(log.cost_microdollars))
                    : "-"}
                </td>
                <td className="px-3 py-2 text-right text-xs">
                  {log.input_tokens != null || log.output_tokens != null
                    ? `${formatTokens(Number(log.input_tokens ?? 0))} / ${formatTokens(Number(log.output_tokens ?? 0))}`
                    : "-"}
                </td>
                <td className="px-3 py-2 text-right text-xs">
                  {log.duration != null
                    ? formatDuration(Number(log.duration))
                    : "-"}
                </td>
                <td className="px-3 py-2 text-xs text-muted-foreground">
                  {log.graph_exec_id
                    ? String(log.graph_exec_id).slice(0, 8)
                    : "-"}
                </td>
              </tr>
            ))}
            {logs.length === 0 && (
              <tr>
                <td
                  colSpan={10}
                  className="px-4 py-8 text-center text-muted-foreground"
                >
                  No logs found
                </td>
              </tr>
            )}
          </tbody>
        </table>
      </div>

      {pagination && pagination.total_pages > 1 && (
        <div className="flex items-center justify-between px-4">
          <span className="text-sm text-muted-foreground">
            Page {pagination.current_page} of {pagination.total_pages} (
            {pagination.total_items} total)
          </span>
          <div className="flex gap-2">
            <button
              disabled={pagination.current_page <= 1}
              onClick={() => onPageChange(pagination.current_page - 1)}
              className="rounded border px-3 py-1 text-sm disabled:opacity-50"
            >
              Previous
            </button>
            <button
              disabled={pagination.current_page >= pagination.total_pages}
              onClick={() => onPageChange(pagination.current_page + 1)}
              className="rounded border px-3 py-1 text-sm disabled:opacity-50"
            >
              Next
            </button>
          </div>
        </div>
      )}
    </div>
  );
}

export { LogsTable };
@@ -0,0 +1,233 @@
"use client";

import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { formatMicrodollars } from "../helpers";
import { SummaryCard } from "./SummaryCard";
import { ProviderTable } from "./ProviderTable";
import { UserTable } from "./UserTable";
import { LogsTable } from "./LogsTable";
import { usePlatformCostContent } from "./usePlatformCostContent";

interface Props {
  searchParams: {
    start?: string;
    end?: string;
    provider?: string;
    user_id?: string;
    page?: string;
    tab?: string;
  };
}

function PlatformCostContent({ searchParams }: Props) {
  const {
    dashboard,
    logs,
    pagination,
    loading,
    error,
    totalEstimatedCost,
    tab,
    startInput,
    setStartInput,
    endInput,
    setEndInput,
    providerInput,
    setProviderInput,
    userInput,
    setUserInput,
    rateOverrides,
    handleRateOverride,
    updateUrl,
    handleFilter,
  } = usePlatformCostContent(searchParams);

  return (
    <div className="flex flex-col gap-6">
      <div className="flex flex-wrap items-end gap-3 rounded-lg border p-4">
        <div className="flex flex-col gap-1">
          <label htmlFor="start-date" className="text-sm text-muted-foreground">
            Start Date <span className="text-xs">(local time)</span>
          </label>
          <input
            id="start-date"
            type="datetime-local"
            className="rounded border px-3 py-1.5 text-sm"
            value={startInput}
            onChange={(e) => setStartInput(e.target.value)}
          />
        </div>
        <div className="flex flex-col gap-1">
          <label htmlFor="end-date" className="text-sm text-muted-foreground">
            End Date <span className="text-xs">(local time)</span>
          </label>
          <input
            id="end-date"
            type="datetime-local"
            className="rounded border px-3 py-1.5 text-sm"
            value={endInput}
            onChange={(e) => setEndInput(e.target.value)}
          />
        </div>
        <div className="flex flex-col gap-1">
          <label
            htmlFor="provider-filter"
            className="text-sm text-muted-foreground"
          >
            Provider
          </label>
          <input
            id="provider-filter"
            type="text"
            placeholder="e.g. openai"
            className="rounded border px-3 py-1.5 text-sm"
            value={providerInput}
            onChange={(e) => setProviderInput(e.target.value)}
          />
        </div>
        <div className="flex flex-col gap-1">
          <label
            htmlFor="user-id-filter"
            className="text-sm text-muted-foreground"
          >
            User ID
          </label>
          <input
            id="user-id-filter"
            type="text"
            placeholder="Filter by user"
            className="rounded border px-3 py-1.5 text-sm"
            value={userInput}
            onChange={(e) => setUserInput(e.target.value)}
          />
        </div>
        <button
          onClick={handleFilter}
          className="rounded bg-primary px-4 py-1.5 text-sm text-primary-foreground hover:bg-primary/90"
        >
          Apply
        </button>
        <button
          onClick={() => {
            setStartInput("");
            setEndInput("");
            setProviderInput("");
            setUserInput("");
            updateUrl({
              start: "",
              end: "",
              provider: "",
              user_id: "",
              page: "1",
            });
          }}
          className="rounded border px-4 py-1.5 text-sm hover:bg-muted"
        >
          Clear
        </button>
      </div>

      {error && (
        <Alert variant="error">
          <AlertDescription>{error}</AlertDescription>
        </Alert>
      )}

      {loading ? (
        <div className="flex flex-col gap-4">
          <div className="grid grid-cols-2 gap-4 md:grid-cols-4">
            {[...Array(4)].map((_, i) => (
              <Skeleton key={i} className="h-20 rounded-lg" />
            ))}
          </div>
          <Skeleton className="h-8 w-48 rounded" />
          <Skeleton className="h-64 rounded-lg" />
        </div>
      ) : (
        <>
          {dashboard && (
            <div className="grid grid-cols-2 gap-4 md:grid-cols-4">
              <SummaryCard
                label="Known Cost"
                value={formatMicrodollars(dashboard.total_cost_microdollars)}
                subtitle="From providers that report USD cost"
              />
              <SummaryCard
                label="Estimated Total"
                value={formatMicrodollars(totalEstimatedCost)}
                subtitle="Including per-run cost estimates"
              />
              <SummaryCard
                label="Total Requests"
                value={dashboard.total_requests.toLocaleString()}
              />
              <SummaryCard
                label="Active Users"
                value={dashboard.total_users.toLocaleString()}
              />
            </div>
          )}

          <div
            role="tablist"
            aria-label="Cost view tabs"
            className="flex gap-2 border-b"
          >
            {["overview", "by-user", "logs"].map((t) => (
              <button
                key={t}
                id={`tab-${t}`}
                role="tab"
                aria-selected={tab === t}
                aria-controls={`tabpanel-${t}`}
                onClick={() => updateUrl({ tab: t, page: "1" })}
                className={`px-4 py-2 text-sm font-medium ${tab === t ? "border-b-2 border-primary text-primary" : "text-muted-foreground hover:text-foreground"}`}
              >
                {t === "overview"
                  ? "By Provider"
                  : t === "by-user"
                    ? "By User"
                    : "Raw Logs"}
              </button>
            ))}
          </div>

          {tab === "overview" && dashboard && (
            <div
              role="tabpanel"
              id="tabpanel-overview"
              aria-labelledby="tab-overview"
            >
              <ProviderTable
                data={dashboard.by_provider}
                rateOverrides={rateOverrides}
                onRateOverride={handleRateOverride}
              />
            </div>
          )}
          {tab === "by-user" && dashboard && (
            <div
              role="tabpanel"
              id="tabpanel-by-user"
              aria-labelledby="tab-by-user"
            >
              <UserTable data={dashboard.by_user} />
            </div>
          )}
          {tab === "logs" && (
            <div role="tabpanel" id="tabpanel-logs" aria-labelledby="tab-logs">
              <LogsTable
                logs={logs}
                pagination={pagination}
                onPageChange={(p) => updateUrl({ page: p.toString() })}
              />
            </div>
          )}
        </>
      )}
    </div>
  );
}

export { PlatformCostContent };
@@ -0,0 +1,131 @@
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
import {
  defaultRateFor,
  estimateCostForRow,
  formatMicrodollars,
  rateKey,
  rateUnitLabel,
  trackingValue,
} from "../helpers";
import { TrackingBadge } from "./TrackingBadge";

interface Props {
  data: ProviderCostSummary[];
  rateOverrides: Record<string, number>;
  onRateOverride: (key: string, val: number | null) => void;
}

function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
  return (
    <div className="overflow-x-auto">
      <table className="w-full text-left text-sm">
        <thead className="border-b text-xs uppercase text-muted-foreground">
          <tr>
            <th scope="col" className="px-4 py-3">
              Provider
            </th>
            <th scope="col" className="px-4 py-3">
              Type
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Usage
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Requests
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Known Cost
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Est. Cost
            </th>
            <th
              scope="col"
              className="px-4 py-3 text-right"
              title="Per-session only"
            >
              Rate <span className="text-[10px] font-normal">(unsaved)</span>
            </th>
          </tr>
        </thead>
        <tbody>
          {data.map((row) => {
            const est = estimateCostForRow(row, rateOverrides);
            const tt = row.tracking_type || "per_run";
            // For cost_usd rows the provider reports USD directly so rate
            // input doesn't apply; otherwise show an editable input.
            const showRateInput = tt !== "cost_usd";
            const key = rateKey(row.provider, tt);
            const fallback = defaultRateFor(row.provider, tt);
            const currentRate = rateOverrides[key] ?? fallback;
            return (
              <tr key={key} className="border-b hover:bg-muted">
                <td className="px-4 py-3 font-medium">{row.provider}</td>
                <td className="px-4 py-3">
                  <TrackingBadge trackingType={row.tracking_type} />
                </td>
                <td className="px-4 py-3 text-right">{trackingValue(row)}</td>
                <td className="px-4 py-3 text-right">
                  {row.request_count.toLocaleString()}
                </td>
                <td className="px-4 py-3 text-right">
                  {row.total_cost_microdollars > 0
                    ? formatMicrodollars(row.total_cost_microdollars)
                    : "-"}
                </td>
                <td className="px-4 py-3 text-right">
                  {est !== null ? (
                    formatMicrodollars(est)
                  ) : (
                    <span className="text-muted-foreground">-</span>
                  )}
                </td>
                <td className="px-4 py-2 text-right">
                  {showRateInput ? (
                    <div className="flex items-center justify-end gap-1">
                      <input
                        type="number"
                        step="0.0001"
                        min="0"
                        aria-label={`Rate for ${row.provider} (${tt})`}
                        className="w-24 rounded border px-2 py-1 text-right text-xs"
                        placeholder={fallback !== null ? String(fallback) : "0"}
                        value={currentRate ?? ""}
                        onChange={(e) => {
                          const val = parseFloat(e.target.value);
                          if (!isNaN(val)) onRateOverride(key, val);
                          else if (e.target.value === "")
                            onRateOverride(key, null);
                        }}
                      />
                      <span
                        className="text-[10px] text-muted-foreground"
                        title={rateUnitLabel(tt)}
                      >
                        {rateUnitLabel(tt)}
                      </span>
                    </div>
                  ) : (
                    <span className="text-xs text-muted-foreground">auto</span>
                  )}
                </td>
              </tr>
            );
          })}
          {data.length === 0 && (
            <tr>
              <td
                colSpan={7}
                className="px-4 py-8 text-center text-muted-foreground"
              >
                No cost data yet
              </td>
            </tr>
          )}
        </tbody>
      </table>
    </div>
  );
}

export { ProviderTable };
@@ -0,0 +1,19 @@
interface Props {
  label: string;
  value: string;
  subtitle?: string;
}

function SummaryCard({ label, value, subtitle }: Props) {
  return (
    <div className="rounded-lg border p-4">
      <div className="text-sm text-muted-foreground">{label}</div>
      <div className="text-2xl font-bold">{value}</div>
      {subtitle && (
        <div className="mt-1 text-xs text-muted-foreground">{subtitle}</div>
      )}
    </div>
  );
}

export { SummaryCard };
@@ -0,0 +1,25 @@
function TrackingBadge({
  trackingType,
}: {
  trackingType: string | null | undefined;
}) {
  const colors: Record<string, string> = {
    cost_usd: "bg-green-500/10 text-green-700",
    tokens: "bg-blue-500/10 text-blue-700",
    characters: "bg-purple-500/10 text-purple-700",
    sandbox_seconds: "bg-orange-500/10 text-orange-700",
    walltime_seconds: "bg-orange-500/10 text-orange-700",
    items: "bg-pink-500/10 text-pink-700",
    per_run: "bg-muted text-muted-foreground",
  };
  const label = trackingType || "per_run";
  return (
    <span
      className={`inline-block rounded px-1.5 py-0.5 text-[10px] font-medium ${colors[label] || colors.per_run}`}
    >
      {label}
    </span>
  );
}

export { TrackingBadge };
@@ -0,0 +1,75 @@
import type { PlatformCostDashboard } from "@/app/api/__generated__/models/platformCostDashboard";
import { formatMicrodollars, formatTokens } from "../helpers";

interface Props {
  data: PlatformCostDashboard["by_user"];
}

function UserTable({ data }: Props) {
  return (
    <div className="overflow-x-auto">
      <table className="w-full text-left text-sm">
        <thead className="border-b text-xs uppercase text-muted-foreground">
          <tr>
            <th scope="col" className="px-4 py-3">
              User
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Known Cost
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Requests
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Input Tokens
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Output Tokens
            </th>
          </tr>
        </thead>
        <tbody>
          {data.map((row, idx) => (
            <tr
              key={row.user_id ?? `unknown-${idx}`}
              className="border-b hover:bg-muted"
            >
              <td className="px-4 py-3">
                <div className="font-medium">{row.email || "Unknown"}</div>
                <div className="text-xs text-muted-foreground">
                  {row.user_id}
                </div>
              </td>
              <td className="px-4 py-3 text-right">
                {row.total_cost_microdollars > 0
                  ? formatMicrodollars(row.total_cost_microdollars)
                  : "-"}
              </td>
              <td className="px-4 py-3 text-right">
                {row.request_count.toLocaleString()}
              </td>
              <td className="px-4 py-3 text-right">
                {formatTokens(row.total_input_tokens)}
              </td>
              <td className="px-4 py-3 text-right">
                {formatTokens(row.total_output_tokens)}
              </td>
            </tr>
          ))}
          {data.length === 0 && (
            <tr>
              <td
                colSpan={5}
                className="px-4 py-8 text-center text-muted-foreground"
              >
                No cost data yet
              </td>
            </tr>
          )}
        </tbody>
      </table>
    </div>
  );
}

export { UserTable };
@@ -0,0 +1,136 @@
"use client";

import { useRouter, useSearchParams } from "next/navigation";
import { useState } from "react";
import {
  useGetV2GetPlatformCostDashboard,
  useGetV2GetPlatformCostLogs,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { okData } from "@/app/api/helpers";
import { estimateCostForRow, toLocalInput, toUtcIso } from "../helpers";

interface InitialSearchParams {
  start?: string;
  end?: string;
  provider?: string;
  user_id?: string;
  page?: string;
  tab?: string;
}

export function usePlatformCostContent(searchParams: InitialSearchParams) {
  const router = useRouter();
  const urlParams = useSearchParams();

  const tab = urlParams.get("tab") || searchParams.tab || "overview";
  const page = parseInt(urlParams.get("page") || searchParams.page || "1", 10);
  const startDate = urlParams.get("start") || searchParams.start || "";
  const endDate = urlParams.get("end") || searchParams.end || "";
  const providerFilter =
    urlParams.get("provider") || searchParams.provider || "";
  const userFilter = urlParams.get("user_id") || searchParams.user_id || "";

  const [startInput, setStartInput] = useState(toLocalInput(startDate));
  const [endInput, setEndInput] = useState(toLocalInput(endDate));
  const [providerInput, setProviderInput] = useState(providerFilter);
  const [userInput, setUserInput] = useState(userFilter);
  const [rateOverrides, setRateOverrides] = useState<Record<string, number>>(
    {},
  );

  // Pass ISO date strings through `as unknown as Date` so Orval's URL builder
  // forwards them as-is. Date.toString() produces a format FastAPI rejects;
  // strings pass through .toString() unchanged.
  const filterParams = {
    start: (startDate || undefined) as unknown as Date | undefined,
    end: (endDate || undefined) as unknown as Date | undefined,
    provider: providerFilter || undefined,
    user_id: userFilter || undefined,
  };

  const {
    data: dashboard,
    isLoading: dashLoading,
    error: dashError,
  } = useGetV2GetPlatformCostDashboard(filterParams, {
    query: { select: okData },
  });

  const {
    data: logsResponse,
    isLoading: logsLoading,
    error: logsError,
  } = useGetV2GetPlatformCostLogs(
    { ...filterParams, page, page_size: 50 },
    { query: { select: okData } },
  );

  const loading = dashLoading || logsLoading;
  const error = dashError
    ? dashError instanceof Error
      ? dashError.message
      : "Failed to load dashboard"
    : logsError
      ? logsError instanceof Error
        ? logsError.message
        : "Failed to load logs"
      : null;

  function updateUrl(overrides: Record<string, string>) {
    const params = new URLSearchParams(urlParams.toString());
    for (const [k, v] of Object.entries(overrides)) {
      if (v) params.set(k, v);
      else params.delete(k);
    }
    router.push(`/admin/platform-costs?${params.toString()}`);
  }

  function handleFilter() {
    updateUrl({
      start: toUtcIso(startInput),
      end: toUtcIso(endInput),
      provider: providerInput,
      user_id: userInput,
      page: "1",
    });
  }

  function handleRateOverride(key: string, val: number | null) {
    setRateOverrides((prev) => {
      if (val === null) {
        const { [key]: _, ...rest } = prev;
        return rest;
      }
      return { ...prev, [key]: val };
    });
  }

  const totalEstimatedCost =
    dashboard?.by_provider.reduce((sum, row) => {
      const est = estimateCostForRow(row, rateOverrides);
      return sum + (est ?? 0);
    }, 0) ?? 0;

  return {
    dashboard: dashboard ?? null,
    logs: logsResponse?.logs ?? [],
    pagination: logsResponse?.pagination ?? null,
    loading,
    error,
    totalEstimatedCost,
    tab,
    page,
    startInput,
    setStartInput,
    endInput,
    setEndInput,
    providerInput,
    setProviderInput,
    userInput,
    setUserInput,
    rateOverrides,
    handleRateOverride,
    updateUrl,
    handleFilter,
  };
}
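The `updateUrl` helper deletes a query parameter when its override is empty, which is how "Clear" both resets the inputs and strips the filters from the address bar. A standalone sketch of that merge logic (`mergeParams` is an illustrative name, not part of the hook):

```typescript
// Merge overrides into an existing query string: truthy values are set,
// empty values delete the key -- mirroring updateUrl() above.
function mergeParams(
  current: string,
  overrides: Record<string, string>,
): string {
  const params = new URLSearchParams(current);
  for (const [k, v] of Object.entries(overrides)) {
    if (v) params.set(k, v);
    else params.delete(k);
  }
  return params.toString();
}

console.log(mergeParams("provider=openai&page=3", { page: "1", provider: "" }));
// "page=1"
```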
@@ -0,0 +1,204 @@
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";

const MICRODOLLARS_PER_USD = 1_000_000;

// Per-request cost estimates (USD) for providers billed per API call.
export const DEFAULT_COST_PER_RUN: Record<string, number> = {
  google_maps: 0.032, // $0.032/request - Google Maps Places API
  ideogram: 0.08, // $0.08/image - Ideogram standard generation
  nvidia: 0.0, // Free tier - NVIDIA NIM deepfake detection
  screenshotone: 0.01, // ~$0.01/screenshot - ScreenshotOne starter
  zerobounce: 0.008, // $0.008/validation - ZeroBounce
  mem0: 0.01, // ~$0.01/request - Mem0
  openweathermap: 0.0, // Free tier
  webshare_proxy: 0.0, // Flat subscription
  enrichlayer: 0.1, // ~$0.10/profile lookup
  jina: 0.0, // Free tier
};

// Blended per-1K-token rates (USD / 1K tokens) for LLM providers.
export const DEFAULT_COST_PER_1K_TOKENS: Record<string, number> = {
  openai: 0.005,
  anthropic: 0.008,
  groq: 0.0003,
  ollama: 0.0,
  aiml_api: 0.005,
  llama_api: 0.003,
  v0: 0.005,
};

// Per-character rates (USD / 1K characters) for TTS providers.
export const DEFAULT_COST_PER_1K_CHARS: Record<string, number> = {
  unreal_speech: 0.008, // ~$8/1M chars on Starter
  elevenlabs: 0.18, // ~$0.18/1K chars on Starter
  d_id: 0.04, // ~$0.04/1K chars estimated
};

// Per-item rates (USD / item) for item-count billed APIs.
export const DEFAULT_COST_PER_ITEM: Record<string, number> = {
  google_maps: 0.017, // avg of $0.032 nearby + ~$0.015 detail enrich
  apollo: 0.02, // ~$0.02/contact on low-volume tiers
  smartlead: 0.001, // ~$0.001/lead added
};

// Per-second rates (USD / second) for duration-billed providers.
export const DEFAULT_COST_PER_SECOND: Record<string, number> = {
  e2b: 0.000014, // $0.000014/sec (2-core sandbox)
  fal: 0.0005, // varies by model, conservative
  replicate: 0.001, // varies by hardware
  revid: 0.01, // per-second of video
};

export function toDateOrUndefined(val?: string): Date | undefined {
|
||||
if (!val) return undefined;
|
||||
const d = new Date(val);
|
||||
return isNaN(d.getTime()) ? undefined : d;
|
||||
}
|
||||
|
||||
export function formatMicrodollars(microdollars: number) {
|
||||
return `$${(microdollars / MICRODOLLARS_PER_USD).toFixed(4)}`;
|
||||
}
|
||||
|
||||
export function formatTokens(tokens: number) {
|
||||
if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`;
|
||||
if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(1)}K`;
|
||||
return tokens.toString();
|
||||
}
|
||||
|
||||
export function formatDuration(seconds: number) {
|
||||
if (seconds >= 3600) return `${(seconds / 3600).toFixed(1)}h`;
|
||||
if (seconds >= 60) return `${(seconds / 60).toFixed(1)}m`;
|
||||
return `${seconds.toFixed(1)}s`;
|
||||
}
|
||||
|
||||
// Unit label for each tracking type — what the rate input represents.
|
||||
export function rateUnitLabel(trackingType: string | null | undefined): string {
|
||||
switch (trackingType) {
|
||||
case "tokens":
|
||||
return "$/1K tokens";
|
||||
case "characters":
|
||||
return "$/1K chars";
|
||||
case "items":
|
||||
return "$/item";
|
||||
case "sandbox_seconds":
|
||||
case "walltime_seconds":
|
||||
return "$/second";
|
||||
case "per_run":
|
||||
return "$/run";
|
||||
default:
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
// Default rate for a (provider, tracking_type) pair.
|
||||
export function defaultRateFor(
|
||||
provider: string,
|
||||
trackingType: string | null | undefined,
|
||||
): number | null {
|
||||
switch (trackingType) {
|
||||
case "tokens":
|
||||
return DEFAULT_COST_PER_1K_TOKENS[provider] ?? null;
|
||||
case "characters":
|
||||
return DEFAULT_COST_PER_1K_CHARS[provider] ?? null;
|
||||
case "items":
|
||||
return DEFAULT_COST_PER_ITEM[provider] ?? null;
|
||||
case "sandbox_seconds":
|
||||
case "walltime_seconds":
|
||||
return DEFAULT_COST_PER_SECOND[provider] ?? null;
|
||||
case "per_run":
|
||||
return DEFAULT_COST_PER_RUN[provider] ?? null;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
// Overrides are keyed on `${provider}:${tracking_type}` since the same
|
||||
// provider can have multiple rows with different billing models.
|
||||
export function rateKey(
|
||||
provider: string,
|
||||
trackingType: string | null | undefined,
|
||||
): string {
|
||||
return `${provider}:${trackingType ?? "per_run"}`;
|
||||
}
|
||||
|
||||
export function estimateCostForRow(
|
||||
row: ProviderCostSummary,
|
||||
rateOverrides: Record<string, number>,
|
||||
) {
|
||||
const tt = row.tracking_type || "per_run";
|
||||
|
||||
// Providers that report USD directly: use known cost.
|
||||
if (tt === "cost_usd") return row.total_cost_microdollars;
|
||||
|
||||
// Prefer the real USD the provider reported if any, but only for token paths
|
||||
// where OpenRouter piggybacks on the tokens row via x-total-cost.
|
||||
if (tt === "tokens" && row.total_cost_microdollars > 0) {
|
||||
return row.total_cost_microdollars;
|
||||
}
|
||||
|
||||
const rate =
|
||||
rateOverrides[rateKey(row.provider, tt)] ??
|
||||
defaultRateFor(row.provider, tt);
|
||||
if (rate === null || rate === undefined) return null;
|
||||
|
||||
// Compute the amount for this tracking type, then multiply by rate.
|
||||
let amount: number;
|
||||
switch (tt) {
|
||||
case "tokens":
|
||||
// Rate is per-1K tokens.
|
||||
amount = (row.total_input_tokens + row.total_output_tokens) / 1000;
|
||||
break;
|
||||
case "characters":
|
||||
// Rate is per-1K chars. trackingAmount aggregates char counts.
|
||||
amount = (row.total_tracking_amount || 0) / 1000;
|
||||
break;
|
||||
case "items":
|
||||
amount = row.total_tracking_amount || 0;
|
||||
break;
|
||||
case "sandbox_seconds":
|
||||
case "walltime_seconds":
|
||||
amount = row.total_duration_seconds || 0;
|
||||
break;
|
||||
case "per_run":
|
||||
amount = row.request_count;
|
||||
break;
|
||||
default:
|
||||
return row.total_cost_microdollars > 0
|
||||
? row.total_cost_microdollars
|
||||
: null;
|
||||
}
|
||||
|
||||
return Math.round(rate * amount * MICRODOLLARS_PER_USD);
|
||||
}
|
||||
|
||||
export function trackingValue(row: ProviderCostSummary) {
|
||||
const tt = row.tracking_type || "per_run";
|
||||
if (tt === "cost_usd") return formatMicrodollars(row.total_cost_microdollars);
|
||||
if (tt === "tokens") {
|
||||
const tokens = row.total_input_tokens + row.total_output_tokens;
|
||||
return `${formatTokens(tokens)} tokens`;
|
||||
}
|
||||
if (tt === "sandbox_seconds" || tt === "walltime_seconds")
|
||||
return formatDuration(row.total_duration_seconds || 0);
|
||||
if (tt === "characters")
|
||||
return `${formatTokens(Math.round(row.total_tracking_amount || 0))} chars`;
|
||||
if (tt === "items")
|
||||
return `${Math.round(row.total_tracking_amount || 0).toLocaleString()} items`;
|
||||
return `${row.request_count.toLocaleString()} runs`;
|
||||
}
|
||||
|
||||
// URL holds UTC ISO; datetime-local inputs need local "YYYY-MM-DDTHH:mm".
|
||||
export function toLocalInput(iso: string) {
|
||||
if (!iso) return "";
|
||||
const d = new Date(iso);
|
||||
if (isNaN(d.getTime())) return "";
|
||||
const pad = (n: number) => String(n).padStart(2, "0");
|
||||
return `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())}T${pad(d.getHours())}:${pad(d.getMinutes())}`;
|
||||
}
|
||||
|
||||
// datetime-local emits naive local time; convert to UTC ISO so the
|
||||
// backend filter window matches what the admin sees in their browser.
|
||||
export function toUtcIso(local: string) {
|
||||
if (!local) return "";
|
||||
const d = new Date(local);
|
||||
return isNaN(d.getTime()) ? "" : d.toISOString();
|
||||
}
|
||||
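The core pricing math in the helpers file (reported USD wins, then a keyed override, then a default per-1K-token rate, rounded into microdollars) can be exercised in isolation. This is a hedged, tokens-only sketch: `Row`, `estimateTokens`, and `PER_1K_TOKENS` are local stand-ins, since `ProviderCostSummary` is a generated type and the real function handles several tracking types:

```typescript
// Minimal stand-in for the ProviderCostSummary fields the tokens branch reads.
type Row = {
  provider: string;
  tracking_type: string | null;
  total_cost_microdollars: number;
  total_input_tokens: number;
  total_output_tokens: number;
};

const MICRODOLLARS_PER_USD = 1_000_000;
// Illustrative default, matching DEFAULT_COST_PER_1K_TOKENS.anthropic in the diff.
const PER_1K_TOKENS: Record<string, number> = { anthropic: 0.008 };

const rateKey = (provider: string, tt: string | null) =>
  `${provider}:${tt ?? "per_run"}`;

// Tokens-only estimate: provider-reported USD wins; otherwise an admin
// override rate, then the built-in default; null when no rate is known.
function estimateTokens(
  row: Row,
  overrides: Record<string, number>,
): number | null {
  if (row.total_cost_microdollars > 0) return row.total_cost_microdollars;
  const rate =
    overrides[rateKey(row.provider, row.tracking_type)] ??
    PER_1K_TOKENS[row.provider] ??
    null;
  if (rate === null) return null;
  const thousands = (row.total_input_tokens + row.total_output_tokens) / 1000;
  return Math.round(rate * thousands * MICRODOLLARS_PER_USD);
}

const row: Row = {
  provider: "anthropic",
  tracking_type: "tokens",
  total_cost_microdollars: 0,
  total_input_tokens: 150_000,
  total_output_tokens: 50_000,
};

console.log(estimateTokens(row, {})); // 200K tokens at $0.008/1K → 1_600_000
console.log(estimateTokens(row, { "anthropic:tokens": 0.01 })); // → 2_000_000
```

Rounding once at the end (after multiplying into microdollars) keeps the stored integer stable even though the per-1K rates are not exactly representable as binary floats.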
@@ -0,0 +1,50 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { Suspense } from "react";
import { PlatformCostContent } from "./components/PlatformCostContent";

type SearchParams = {
  start?: string;
  end?: string;
  provider?: string;
  user_id?: string;
  page?: string;
  tab?: string;
};

function PlatformCostDashboard({
  searchParams,
}: {
  searchParams: SearchParams;
}) {
  return (
    <div className="mx-auto p-6">
      <div className="flex flex-col gap-4">
        <div>
          <h1 className="text-3xl font-bold">Platform Costs</h1>
          <p className="text-muted-foreground">
            Track real API costs incurred by system credentials across providers
          </p>
        </div>

        <Suspense
          key={JSON.stringify(searchParams)}
          fallback={
            <div className="py-10 text-center">Loading cost data...</div>
          }
        >
          <PlatformCostContent searchParams={searchParams} />
        </Suspense>
      </div>
    </div>
  );
}

export default async function PlatformCostDashboardPage({
  searchParams,
}: {
  searchParams: Promise<SearchParams>;
}) {
  const withAdminAccess = await withRoleAccess(["admin"]);
  const ProtectedDashboard = await withAdminAccess(PlatformCostDashboard);
  return <ProtectedDashboard searchParams={await searchParams} />;
}
@@ -7,6 +7,179 @@
     "version": "0.1"
   },
   "paths": {
     "/api/admin/platform-costs/dashboard": {
       "get": {
         "tags": ["v2", "admin", "platform-cost", "admin"],
         "summary": "Get Platform Cost Dashboard",
         "operationId": "getV2Get platform cost dashboard",
         "security": [{ "HTTPBearerJWT": [] }],
         "parameters": [
           {
             "name": "start",
             "in": "query",
             "required": false,
             "schema": {
               "anyOf": [
                 { "type": "string", "format": "date-time" },
                 { "type": "null" }
               ],
               "title": "Start"
             }
           },
           {
             "name": "end",
             "in": "query",
             "required": false,
             "schema": {
               "anyOf": [
                 { "type": "string", "format": "date-time" },
                 { "type": "null" }
               ],
               "title": "End"
             }
           },
           {
             "name": "provider",
             "in": "query",
             "required": false,
             "schema": {
               "anyOf": [{ "type": "string" }, { "type": "null" }],
               "title": "Provider"
             }
           },
           {
             "name": "user_id",
             "in": "query",
             "required": false,
             "schema": {
               "anyOf": [{ "type": "string" }, { "type": "null" }],
               "title": "User Id"
             }
           }
         ],
         "responses": {
           "200": {
             "description": "Successful Response",
             "content": {
               "application/json": {
                 "schema": {
                   "$ref": "#/components/schemas/PlatformCostDashboard"
                 }
               }
             }
           },
           "401": {
             "$ref": "#/components/responses/HTTP401NotAuthenticatedError"
           },
           "422": {
             "description": "Validation Error",
             "content": {
               "application/json": {
                 "schema": { "$ref": "#/components/schemas/HTTPValidationError" }
               }
             }
           }
         }
       }
     },
     "/api/admin/platform-costs/logs": {
       "get": {
         "tags": ["v2", "admin", "platform-cost", "admin"],
         "summary": "Get Platform Cost Logs",
         "operationId": "getV2Get platform cost logs",
         "security": [{ "HTTPBearerJWT": [] }],
         "parameters": [
           {
             "name": "start",
             "in": "query",
             "required": false,
             "schema": {
               "anyOf": [
                 { "type": "string", "format": "date-time" },
                 { "type": "null" }
               ],
               "title": "Start"
             }
           },
           {
             "name": "end",
             "in": "query",
             "required": false,
             "schema": {
               "anyOf": [
                 { "type": "string", "format": "date-time" },
                 { "type": "null" }
               ],
               "title": "End"
             }
           },
           {
             "name": "provider",
             "in": "query",
             "required": false,
             "schema": {
               "anyOf": [{ "type": "string" }, { "type": "null" }],
               "title": "Provider"
             }
           },
           {
             "name": "user_id",
             "in": "query",
             "required": false,
             "schema": {
               "anyOf": [{ "type": "string" }, { "type": "null" }],
               "title": "User Id"
             }
           },
           {
             "name": "page",
             "in": "query",
             "required": false,
             "schema": {
               "type": "integer",
               "minimum": 1,
               "default": 1,
               "title": "Page"
             }
           },
           {
             "name": "page_size",
             "in": "query",
             "required": false,
             "schema": {
               "type": "integer",
               "maximum": 200,
               "minimum": 1,
               "default": 50,
               "title": "Page Size"
             }
           }
         ],
         "responses": {
           "200": {
             "description": "Successful Response",
             "content": {
               "application/json": {
                 "schema": {
                   "$ref": "#/components/schemas/PlatformCostLogsResponse"
                 }
               }
             }
           },
           "401": {
             "$ref": "#/components/responses/HTTP401NotAuthenticatedError"
           },
           "422": {
             "description": "Validation Error",
             "content": {
               "application/json": {
                 "schema": { "$ref": "#/components/schemas/HTTPValidationError" }
               }
             }
           }
         }
       }
     },
     "/api/analytics/log_raw_analytics": {
       "post": {
         "tags": ["analytics"],
@@ -8733,6 +8906,61 @@
         ],
         "title": "ContentType"
       },
       "CostLogRow": {
         "properties": {
           "id": { "type": "string", "title": "Id" },
           "created_at": {
             "type": "string",
             "format": "date-time",
             "title": "Created At"
           },
           "user_id": {
             "anyOf": [{ "type": "string" }, { "type": "null" }],
             "title": "User Id"
           },
           "email": {
             "anyOf": [{ "type": "string" }, { "type": "null" }],
             "title": "Email"
           },
           "graph_exec_id": {
             "anyOf": [{ "type": "string" }, { "type": "null" }],
             "title": "Graph Exec Id"
           },
           "node_exec_id": {
             "anyOf": [{ "type": "string" }, { "type": "null" }],
             "title": "Node Exec Id"
           },
           "block_name": { "type": "string", "title": "Block Name" },
           "provider": { "type": "string", "title": "Provider" },
           "tracking_type": {
             "anyOf": [{ "type": "string" }, { "type": "null" }],
             "title": "Tracking Type"
           },
           "cost_microdollars": {
             "anyOf": [{ "type": "integer" }, { "type": "null" }],
             "title": "Cost Microdollars"
           },
           "input_tokens": {
             "anyOf": [{ "type": "integer" }, { "type": "null" }],
             "title": "Input Tokens"
           },
           "output_tokens": {
             "anyOf": [{ "type": "integer" }, { "type": "null" }],
             "title": "Output Tokens"
           },
           "duration": {
             "anyOf": [{ "type": "number" }, { "type": "null" }],
             "title": "Duration"
           },
           "model": {
             "anyOf": [{ "type": "string" }, { "type": "null" }],
             "title": "Model"
           }
         },
         "type": "object",
         "required": ["id", "created_at", "block_name", "provider"],
         "title": "CostLogRow"
       },
       "CountResponse": {
         "properties": {
           "all_blocks": { "type": "integer", "title": "All Blocks" },
@@ -11664,6 +11892,48 @@
         "title": "PendingHumanReviewModel",
         "description": "Response model for pending human review data.\n\nRepresents a human review request that is awaiting user action.\nContains all necessary information for a user to review and approve\nor reject data from a Human-in-the-Loop block execution.\n\nAttributes:\n    id: Unique identifier for the review record\n    user_id: ID of the user who must perform the review\n    node_exec_id: ID of the node execution that created this review\n    node_id: ID of the node definition (for grouping reviews from same node)\n    graph_exec_id: ID of the graph execution containing the node\n    graph_id: ID of the graph template being executed\n    graph_version: Version number of the graph template\n    payload: The actual data payload awaiting review\n    instructions: Instructions or message for the reviewer\n    editable: Whether the reviewer can edit the data\n    status: Current review status (WAITING, APPROVED, or REJECTED)\n    review_message: Optional message from the reviewer\n    created_at: Timestamp when review was created\n    updated_at: Timestamp when review was last modified\n    reviewed_at: Timestamp when review was completed (if applicable)"
       },
       "PlatformCostDashboard": {
         "properties": {
           "by_provider": {
             "items": { "$ref": "#/components/schemas/ProviderCostSummary" },
             "type": "array",
             "title": "By Provider"
           },
           "by_user": {
             "items": { "$ref": "#/components/schemas/UserCostSummary" },
             "type": "array",
             "title": "By User"
           },
           "total_cost_microdollars": {
             "type": "integer",
             "title": "Total Cost Microdollars"
           },
           "total_requests": { "type": "integer", "title": "Total Requests" },
           "total_users": { "type": "integer", "title": "Total Users" }
         },
         "type": "object",
         "required": [
           "by_provider",
           "by_user",
           "total_cost_microdollars",
           "total_requests",
           "total_users"
         ],
         "title": "PlatformCostDashboard"
       },
       "PlatformCostLogsResponse": {
         "properties": {
           "logs": {
             "items": { "$ref": "#/components/schemas/CostLogRow" },
             "type": "array",
             "title": "Logs"
           },
           "pagination": { "$ref": "#/components/schemas/Pagination" }
         },
         "type": "object",
         "required": ["logs", "pagination"],
         "title": "PlatformCostLogsResponse"
       },
       "PostmarkBounceEnum": {
         "type": "integer",
         "enum": [
@@ -12058,6 +12328,47 @@
         "title": "ProviderConstants",
         "description": "Model that exposes all provider names as a constant in the OpenAPI schema.\nThis is designed to be converted by Orval into a TypeScript constant."
       },
       "ProviderCostSummary": {
         "properties": {
           "provider": { "type": "string", "title": "Provider" },
           "tracking_type": {
             "anyOf": [{ "type": "string" }, { "type": "null" }],
             "title": "Tracking Type"
           },
           "total_cost_microdollars": {
             "type": "integer",
             "title": "Total Cost Microdollars"
           },
           "total_input_tokens": {
             "type": "integer",
             "title": "Total Input Tokens"
           },
           "total_output_tokens": {
             "type": "integer",
             "title": "Total Output Tokens"
           },
           "total_duration_seconds": {
             "type": "number",
             "title": "Total Duration Seconds",
             "default": 0.0
           },
           "total_tracking_amount": {
             "type": "number",
             "title": "Total Tracking Amount",
             "default": 0.0
           },
           "request_count": { "type": "integer", "title": "Request Count" }
         },
         "type": "object",
         "required": [
           "provider",
           "total_cost_microdollars",
           "total_input_tokens",
           "total_output_tokens",
           "request_count"
         ],
         "title": "ProviderCostSummary"
       },
       "ProviderEnumResponse": {
         "properties": {
           "provider": {
@@ -14938,6 +15249,39 @@
         "title": "UsageWindow",
         "description": "Usage within a single time window."
       },
       "UserCostSummary": {
         "properties": {
           "user_id": {
             "anyOf": [{ "type": "string" }, { "type": "null" }],
             "title": "User Id"
           },
           "email": {
             "anyOf": [{ "type": "string" }, { "type": "null" }],
             "title": "Email"
           },
           "total_cost_microdollars": {
             "type": "integer",
             "title": "Total Cost Microdollars"
           },
           "total_input_tokens": {
             "type": "integer",
             "title": "Total Input Tokens"
           },
           "total_output_tokens": {
             "type": "integer",
             "title": "Total Output Tokens"
           },
           "request_count": { "type": "integer", "title": "Request Count" }
         },
         "type": "object",
         "required": [
           "total_cost_microdollars",
           "total_input_tokens",
           "total_output_tokens",
           "request_count"
         ],
         "title": "UserCostSummary"
       },
       "UserHistoryResponse": {
         "properties": {
           "history": {
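For a client consuming the logs endpoint added in the spec above, the query parameters (`start`, `end`, `provider`, `user_id`, `page`, `page_size`) are all optional. A hedged sketch of URL construction — `logsUrl` is a hypothetical helper, not part of this PR's generated client:

```typescript
// Hypothetical request-URL builder for GET /api/admin/platform-costs/logs.
// Parameter names mirror the OpenAPI spec; all are optional query params.
type LogsQuery = {
  start?: string; // UTC ISO date-time
  end?: string;
  provider?: string;
  user_id?: string;
  page?: number; // >= 1, default 1
  page_size?: number; // 1..200, default 50
};

function logsUrl(opts: LogsQuery): string {
  const params = new URLSearchParams();
  for (const [k, v] of Object.entries(opts)) {
    // Skip unset/empty values so server-side defaults apply.
    if (v !== undefined && v !== "") params.set(k, String(v));
  }
  const qs = params.toString();
  return qs
    ? `/api/admin/platform-costs/logs?${qs}`
    : "/api/admin/platform-costs/logs";
}

console.log(logsUrl({ provider: "openai", page: 2, page_size: 50 }));
// → /api/admin/platform-costs/logs?provider=openai&page=2&page_size=50
console.log(logsUrl({}));
// → /api/admin/platform-costs/logs
```

Per the spec, the endpoint also requires a bearer JWT (`HTTPBearerJWT` security scheme), so a real request would attach an `Authorization: Bearer …` header.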
BIN  test-screenshots/PR-12696/01-after-login.png          (new file, 78 KiB)
BIN  test-screenshots/PR-12696/02-admin-platform-costs.png (new file, 88 KiB)
BIN  test-screenshots/PR-12696/03-by-provider-table.png    (new file, 65 KiB)
BIN  test-screenshots/PR-12696/04-by-user-tab.png          (new file, 86 KiB)
BIN  test-screenshots/PR-12696/05-by-user-rows.png         (new file, 46 KiB)
BIN  test-screenshots/PR-12696/06-raw-logs-tab.png         (new file, 124 KiB)
BIN  test-screenshots/PR-12696/07-provider-filter.png      (new file, 88 KiB)
BIN  test-screenshots/PR-12696/08-retest-dashboard.png     (new file, 88 KiB)