mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Merge branch 'dev' of https://github.com/Significant-Gravitas/AutoGPT into feat/task-decomposition-copilot
This commit is contained in:
709
.claude/skills/orchestrate/SKILL.md
Normal file
@@ -0,0 +1,709 @@
---
name: orchestrate
description: "Meta-agent supervisor that manages a fleet of Claude Code agents running in tmux windows. Auto-discovers spare worktrees, spawns agents, monitors state, kicks idle agents, approves safe confirmations, and recycles worktrees when done. TRIGGER when user asks to supervise agents, run parallel tasks, manage worktrees, check agent status, or orchestrate parallel work."
user-invocable: true
argument-hint: "any free text — e.g. 'start 3 agents on X Y Z', 'show status', 'add task: implement feature A', 'stop', 'how many are free?'"
metadata:
  author: autogpt-team
  version: "6.0.0"
---

# Orchestrate — Agent Fleet Supervisor

One tmux session, N windows — each window is one agent working in its own worktree. Speak naturally; Claude maps your intent to the right scripts.

## Scripts

```bash
SKILLS_DIR=$(git rev-parse --show-toplevel)/.claude/skills/orchestrate/scripts
STATE_FILE=~/.claude/orchestrator-state.json
```

| Script | Purpose |
|---|---|
| `find-spare.sh [REPO_ROOT]` | List free worktrees — one `PATH BRANCH` per line |
| `spawn-agent.sh SESSION PATH SPARE NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]` | Create window + checkout branch + launch claude + send task. **Stdout: `SESSION:WIN` only** |
| `recycle-agent.sh WINDOW PATH SPARE_BRANCH` | Kill window + restore spare branch |
| `run-loop.sh` | **Mechanical babysitter** — idle restart + dialog approval + recycle on ORCHESTRATOR:DONE + supervisor health check + all-done notification |
| `verify-complete.sh WINDOW` | Verify PR is done: checkpoints ✓ + 0 unresolved threads + CI green + no fresh CHANGES_REQUESTED. Repo auto-derived from state file `.repo` or git remote. |
| `notify.sh MESSAGE` | Send notification via Discord webhook (env `DISCORD_WEBHOOK_URL` or state `.discord_webhook`), macOS notification center, and stdout |
| `capacity.sh [REPO_ROOT]` | Print available + in-use worktrees |
| `status.sh` | Print fleet status + live pane commands |
| `poll-cycle.sh` | One monitoring cycle — classifies panes, tracks checkpoints, returns JSON action array |
| `classify-pane.sh WINDOW` | Classify one pane state |

## Supervision model

```
Orchestrating Claude (this Claude session — IS the supervisor)
└── Reads pane output, checks CI, intervenes with targeted guidance

run-loop.sh (separate tmux window, every 30s)
└── Mechanical only: idle restart, dialog approval, recycle on ORCHESTRATOR:DONE
```

**You (the orchestrating Claude)** are the supervisor. After spawning agents, stay in this conversation and actively monitor: poll each agent's pane every 2-3 minutes, check CI, nudge stalled agents, and verify completions. Do not spawn a separate supervisor Claude window — it loses context, is hard to observe, and compounds context compression problems.

**run-loop.sh** is the mechanical layer — zero tokens; it handles only what needs no judgment: restarting crashed agents, pressing Enter on dialogs, and recycling completed worktrees (only after `verify-complete.sh` passes).

## Checkpoint protocol

Agents output checkpoints as they complete each required step:

```
CHECKPOINT:<step-name>
```

Required steps are passed as args to `spawn-agent.sh` (e.g. `pr-address pr-test`). `run-loop.sh` will not recycle a window until all required checkpoints are found in the pane output. If `verify-complete.sh` fails, the agent is re-briefed automatically.
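The gate `run-loop.sh` applies before recycling can be sketched as a grep over captured pane text. This is an illustrative sketch with inline sample output, not the script's actual code:

```bash
# Sample pane text; in practice: PANE=$(tmux capture-pane -t SESSION:WIN -p -S -200)
PANE=$'CHECKPOINT:pr-address\nCHECKPOINT:pr-test\nORCHESTRATOR:DONE'
REQUIRED="pr-address pr-test"

MISSING=""
for step in $REQUIRED; do
  # every required step must appear as CHECKPOINT:<step> in the pane
  echo "$PANE" | grep -q "CHECKPOINT:${step}" || MISSING="$MISSING $step"
done

if [ -z "$MISSING" ]; then
  echo "all checkpoints found"
else
  echo "missing:$MISSING"
fi
```

With the sample pane above, every required step is present; any missing step name would block the recycle.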

## Worktree lifecycle

```text
spare/N branch → spawn-agent.sh (--session-id UUID) → window + feat/branch + claude running
        ↓
CHECKPOINT:<step> (as steps complete)
        ↓
ORCHESTRATOR:DONE
        ↓
verify-complete.sh: checkpoints ✓ + 0 threads + CI green + no fresh CHANGES_REQUESTED
        ↓
state → "done", notify, window KEPT OPEN
        ↓
user/orchestrator explicitly requests recycle
        ↓
recycle-agent.sh → spare/N (free again)
```

**Windows are never auto-killed.** The worktree stays on its branch and the session stays alive. The agent is done working, but the window, git state, and Claude session are all preserved until you choose to recycle.

**To resume a done or crashed session:**

```bash
# Resume by stored session ID (preferred — exact session, full context)
claude --resume SESSION_ID --permission-mode bypassPermissions

# Or resume the most recent session in that worktree directory
cd /path/to/worktree && claude --continue --permission-mode bypassPermissions
```

**To manually recycle when ready:**

```bash
bash ~/.claude/orchestrator/scripts/recycle-agent.sh SESSION:WIN WORKTREE_PATH spare/N
# Then update state:
jq --arg w "SESSION:WIN" '.agents |= map(if .window == $w then .state = "recycled" else . end)' \
  ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```

## State file (`~/.claude/orchestrator-state.json`)

Never committed to git. You maintain this file directly using `jq` + atomic writes (`.tmp` → `mv`).
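The atomic-write pattern in miniature (a throwaway temp file stands in for the real state file):

```bash
# Write jq output to a temp file, then mv into place, so readers
# never see a half-written state file.
STATE_FILE=$(mktemp)
echo '{"active":true,"agents":[]}' > "$STATE_FILE"

jq '.last_poll_at = 1700000000' "$STATE_FILE" > "${STATE_FILE}.tmp" \
  && mv "${STATE_FILE}.tmp" "$STATE_FILE"

jq -r '.last_poll_at' "$STATE_FILE"
```

`mv` within the same filesystem is atomic, which is why the two-step write is safe even while `run-loop.sh` is reading the file.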

```json
{
  "active": true,
  "tmux_session": "autogpt1",
  "idle_threshold_seconds": 300,
  "loop_window": "autogpt1:5",
  "repo": "Significant-Gravitas/AutoGPT",
  "discord_webhook": "https://discord.com/api/webhooks/...",
  "last_poll_at": 0,
  "agents": [
    {
      "window": "autogpt1:3",
      "worktree": "AutoGPT6",
      "worktree_path": "/path/to/AutoGPT6",
      "spare_branch": "spare/6",
      "branch": "feat/my-feature",
      "objective": "Implement X and open a PR",
      "pr_number": "12345",
      "session_id": "550e8400-e29b-41d4-a716-446655440000",
      "steps": ["pr-address", "pr-test"],
      "checkpoints": ["pr-address"],
      "state": "running",
      "last_output_hash": "",
      "last_seen_at": 0,
      "spawned_at": 0,
      "idle_since": 0,
      "revision_count": 0,
      "last_rebriefed_at": 0
    }
  ]
}
```

Top-level optional fields:

- `repo` — GitHub `owner/repo` for CI/thread checks. Auto-derived from git remote if omitted.
- `discord_webhook` — Discord webhook URL for completion notifications. Also reads `DISCORD_WEBHOOK_URL` env var.

Per-agent fields:

- `session_id` — UUID passed to `claude --session-id` at spawn; use with `claude --resume UUID` to restore exact session context after a crash or window close.
- `last_rebriefed_at` — Unix timestamp of the last re-brief; enforces a 5-minute cooldown to prevent spam.

Agent states: `running` | `idle` | `stuck` | `waiting_approval` | `complete` | `done` | `escalated`

`done` means verified complete — the window is still open, the session is still alive, and the worktree is still on its task branch. Not recycled yet.
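A quick way to see the fleet's state distribution with `jq` (sample data inline; in practice read the state file):

```bash
# Tally agents by state. group_by sorts groups by the grouping key.
STATE='{"agents":[{"state":"running"},{"state":"done"},{"state":"running"}]}'
echo "$STATE" | jq -r '.agents | group_by(.state) | map("\(.[0].state): \(length)") | .[]'
```

For the sample above this prints `done: 1` and `running: 2`, one per line.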

## Serial /pr-test rule

`/pr-test` and `/pr-test --fix` run local Docker + integration tests that use shared ports, a shared database, and shared build caches. **Running two `/pr-test` jobs simultaneously will cause port conflicts and database corruption.**

**Rule: only one `/pr-test` runs at a time. The orchestrator serializes them.**

You (the orchestrating Claude) own the test queue:

1. Agents do `pr-review` and `pr-address` in parallel — that's safe (they only push code and reply to GitHub).
2. When a PR needs local testing, add it to your mental queue — don't give agents a `pr-test` step.
3. Run `/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER --fix` yourself, sequentially.
4. Feed results back to the relevant agent via `tmux send-keys`:

   ```bash
   tmux send-keys -t SESSION:WIN "Local tests for PR #N: <paste failure output or 'all passed'>. Fix any failures and push, then output ORCHESTRATOR:DONE."
   sleep 0.3
   tmux send-keys -t SESSION:WIN Enter
   ```

5. Wait for CI to confirm green before marking the agent done.

If multiple PRs need testing at the same time, pick the one furthest along (fewest pending CI checks) and test it first. Only start the next test after the previous one completes.
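That selection rule can be sketched with `jq` over sample pending-check counts (the counts here are sample data; in practice derive them from `gh pr checks`):

```bash
# Pick the PR with the fewest pending CI checks as the next /pr-test target.
QUEUE='[{"pr":12636,"pending":3},{"pr":12699,"pending":1}]'
NEXT=$(echo "$QUEUE" | jq 'min_by(.pending).pr')
echo "next /pr-test target: PR #$NEXT"
```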

## Session restore (tested and confirmed)

Agent sessions are saved to disk. To restore a closed or crashed session:

```bash
# If session_id is in state (preferred):
NEW_WIN=$(tmux new-window -t SESSION -n WORKTREE_NAME -P -F '#{window_index}')
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --resume SESSION_ID --permission-mode bypassPermissions" Enter

# If no session_id (use --continue for the most recent session in that directory):
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --continue --permission-mode bypassPermissions" Enter
```

`--continue` restores the full conversation history, including all tool calls, file edits, and context; the agent resumes exactly where it left off. After restoring, update the window address in the state file:

```bash
jq --arg old "SESSION:OLD_WIN" --arg new "SESSION:NEW_WIN" \
  '(.agents[] | select(.window == $old)).window = $new' \
  ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```

## Intent → action mapping

Match the user's message to one of these intents:

| The user says something like… | What to do |
|---|---|
| "status", "what's running", "show agents" | Run `status.sh` + `capacity.sh`, show output |
| "how many free", "capacity", "available worktrees" | Run `capacity.sh`, show output |
| "start N agents on X, Y, Z" or "run these tasks: …" | See **Spawning agents** below |
| "add task: …", "add one more agent for …" | See **Adding an agent** below |
| "stop", "shut down", "pause the fleet" | See **When to stop the fleet** below |
| "poll", "check now", "run a cycle" | Run `poll-cycle.sh`, process actions |
| "recycle window X", "free up autogpt3" | Run `recycle-agent.sh` directly |

When the intent is ambiguous, show capacity first and ask what tasks to run.

## Spawning agents

### 1. Resolve tmux session

```bash
tmux list-sessions -F "#{session_name}: #{session_windows} windows" 2>/dev/null
```

Use an existing session. **Never create a tmux session from within Claude** — it becomes a child of Claude's process and dies when the session ends. If no session exists, tell the user to run `tmux new-session -d -s autogpt1` in their terminal first, then re-invoke `/orchestrate`.

### 2. Show available capacity

```bash
bash $SKILLS_DIR/capacity.sh $(git rev-parse --show-toplevel)
```

### 3. Collect tasks from the user

For each task, gather:

- **objective** — what to do (e.g. "implement feature X and open a PR")
- **branch name** — e.g. `feat/my-feature` (derive from the objective if not given)
- **pr_number** — GitHub PR number if working on an existing PR (for verification)
- **steps** — required checkpoint names in order (e.g. `pr-address pr-test`) — derive from the objective

Ask for `idle_threshold_seconds` only if the user mentions it (default: 300).

Never ask the user to specify a worktree — auto-assign from `find-spare.sh`.
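One possible way to derive a branch name when the user doesn't give one; the `feat/` prefix and slug scheme here are illustrative assumptions, not rules enforced by `spawn-agent.sh`:

```bash
# Slugify the objective into a branch name: lowercase, non-alphanumerics
# collapsed to hyphens, trimmed, capped at 40 chars.
OBJECTIVE="Implement feature X and open a PR"
SLUG=$(echo "$OBJECTIVE" | tr '[:upper:]' '[:lower:]' | tr -cs 'a-z0-9' '-' | sed 's/^-//; s/-$//' | cut -c1-40)
BRANCH="feat/${SLUG}"
echo "$BRANCH"
```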

### 4. Spawn one agent per task

```bash
# Get the ordered list of spare worktrees
SPARE_LIST=$(bash $SKILLS_DIR/find-spare.sh $(git rev-parse --show-toplevel))

# For each task, take the next spare line:
WORKTREE_PATH=$(echo "$SPARE_LINE" | awk '{print $1}')
SPARE_BRANCH=$(echo "$SPARE_LINE" | awk '{print $2}')

# With PR number and required steps:
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE" "$PR_NUMBER" "pr-address" "pr-test")

# Without PR (new work):
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE")
```

If the state file doesn't exist yet, initialize it before spawning:

```bash
# Derive repo from git remote (used by verify-complete.sh + supervisor)
REPO=$(git remote get-url origin 2>/dev/null | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")

jq -n \
  --arg session "$SESSION" \
  --arg repo "$REPO" \
  --argjson threshold 300 \
  '{active:true, tmux_session:$session, idle_threshold_seconds:$threshold,
    repo:$repo, loop_window:null, supervisor_window:null, last_poll_at:0, agents:[]}' \
  > ~/.claude/orchestrator-state.json
```

Optionally add a Discord webhook for completion notifications:

```bash
jq --arg hook "$DISCORD_WEBHOOK_URL" '.discord_webhook = $hook' ~/.claude/orchestrator-state.json \
  > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```

`spawn-agent.sh` writes the initial agent record (window, worktree_path, branch, objective, state, etc.) to the state file automatically — **do not append the record again after calling it.** The record already exists, and `pr_number`/`steps` are patched in by the script itself.

### 5. Start the mechanical babysitter

```bash
LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter

jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
  > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```

### 6. Begin supervising directly in this conversation

You are the supervisor. After spawning, immediately start your first poll loop (see **Supervisor duties** below) and continue every 2-3 minutes. Do NOT spawn a separate supervisor Claude window.

## Adding an agent

Find the next spare worktree, then spawn — same as steps 2–4 above but for a single task. If no spare worktrees are available, tell the user.

## Supervisor duties (YOUR job, every 2-3 min in this conversation)

You are the supervisor. Run this poll loop directly in your Claude session — not in a separate window.

### Poll loop mechanism

You are reactive — you only act when a tool completes or the user sends a message. To create a self-sustaining poll loop without user involvement:

1. Start each poll with `run_in_background: true` + a sleep before the work:

   ```bash
   sleep 120 && tmux capture-pane -t autogpt1:0 -p -S -200 | tail -40
   # + similar for each active window
   ```

2. When the background job notifies you, read the pane output and take action.
3. Immediately schedule the next background poll — this keeps the loop alive.
4. Stop scheduling when all agents are done/escalated.

**Never tell the user "I'll poll every 2-3 minutes"** — that does nothing without a trigger. Start the background job instead.

### Each poll: what to check

```bash
# 1. Read state
jq '.agents[] | {window, worktree, branch, state, pr_number, checkpoints}' ~/.claude/orchestrator-state.json

# 2. For each running/stuck/idle agent, capture the pane
tmux capture-pane -t SESSION:WIN -p -S -200 | tail -60
```

For each agent, decide:

| What you see | Action |
|---|---|
| Spinner / tools running | Do nothing — agent is working |
| Idle `❯` prompt, no `ORCHESTRATOR:DONE` | Stalled — send a specific nudge with the objective from state |
| Stuck in an error loop | Send a targeted fix with the exact error + solution |
| Waiting for input / question | Answer and unblock via `tmux send-keys` |
| CI red | `gh pr checks PR_NUMBER --repo REPO` → tell the agent exactly what's failing |
| GitHub abuse rate limit error | Nudge: "Wait 60 seconds then continue posting replies with sleep 3 between each" |
| Context compacted / agent lost | Send recovery: `cat ~/.claude/orchestrator-state.json \| jq '.agents[] \| select(.window=="WIN")'` + `gh pr view PR_NUMBER --json title,body` |
| `ORCHESTRATOR:DONE` in output | Query GraphQL for the actual unresolved count. If >0, re-brief. If 0, run `verify-complete.sh` |

**Poll all windows from state, not from memory.** Before each poll, run:

```bash
jq -r '.agents[] | select(.state | test("running|idle|stuck|waiting_approval|pending_evaluation")) | .window' ~/.claude/orchestrator-state.json
```

and capture every window listed. If you manually added a window outside spawn-agent.sh, ensure it's in the state file first.
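A worked example of that query against inline sample state, showing the shape of the output:

```bash
# Only agents in active-ish states are polled; done/recycled are skipped.
STATE='{"agents":[
  {"window":"autogpt1:3","state":"running"},
  {"window":"autogpt1:4","state":"done"},
  {"window":"autogpt1:5","state":"waiting_approval"}]}'
echo "$STATE" | jq -r '.agents[]
  | select(.state | test("running|idle|stuck|waiting_approval|pending_evaluation"))
  | .window'
```

For the sample above this prints `autogpt1:3` and `autogpt1:5`; the `done` window is skipped.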

### RUNNING count includes waiting_approval agents

The `RUNNING` count from run-loop.sh includes agents in the `waiting_approval` state (they match the regex `running|stuck|waiting_approval|idle`). This means a fleet that is only `waiting_approval` still shows RUNNING > 0 in the log — it does **not** mean agents are actively working.

When you see `RUNNING > 0` in the run-loop log but suspect agents are actually blocked, check state directly:

```bash
jq '.agents[] | {window, state, worktree}' ~/.claude/orchestrator-state.json
```

A count of `running=1 waiting=1` in the log actually means one agent is waiting for approval — the orchestrator should check and approve, not wait.

### State file staleness recovery

The state file is written by scripts but can drift from reality when windows are closed, sessions expire, or the orchestrator restarts across conversations.

**Signs of stale state:**

- `loop_window` points to a window that no longer exists in the tmux session
- An agent's `state` is `running` but its tmux window is closed or shows a shell prompt (not claude)
- `last_seen_at` is hours old but the state still says `running`

**Recovery steps:**

1. **Verify actual tmux windows:**

   ```bash
   tmux list-windows -t SESSION -F '#{window_index}: #{window_name} (#{pane_current_command})'
   ```

2. **Cross-reference with the state file:**

   ```bash
   jq -r '.agents[] | "\(.window) \(.state) \(.worktree)"' ~/.claude/orchestrator-state.json
   ```

3. **Fix stale entries:**

   ```bash
   # Agent window closed — mark idle so run-loop.sh will restart it
   jq --arg w "SESSION:WIN" '(.agents[] | select(.window==$w)).state = "idle"' \
     ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json

   # loop_window gone — kill the stale reference, then restart run-loop.sh
   jq '.loop_window = null' ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
   LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
   LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
   tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter
   jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
     > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
   ```

4. **After any state repair, re-run `status.sh` to confirm coherence before resuming supervision.**

### Strict ORCHESTRATOR:DONE gate

`verify-complete.sh` handles the main checks automatically (checkpoints, threads, CI green, spawned_at, and CHANGES_REQUESTED). Run it:

```bash
SKILLS_DIR=~/.claude/orchestrator/scripts
bash $SKILLS_DIR/verify-complete.sh SESSION:WIN
```

**CHANGES_REQUESTED staleness rule**: a `CHANGES_REQUESTED` review only blocks if it was submitted *after* the latest commit. If the latest commit postdates the review, the review is considered stale (feedback already addressed) and does not block. This avoids false negatives when a bot reviewer hasn't re-reviewed after the agent's fixing commits.
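The staleness comparison reduces to a timestamp check; since ISO-8601 UTC timestamps compare lexicographically, a plain string comparison works. Sample values shown:

```bash
# A CHANGES_REQUESTED review blocks only if submitted after the latest commit.
LAST_COMMIT_AT="2025-06-01T12:00:00Z"
REVIEW_AT="2025-05-30T09:00:00Z"   # when CHANGES_REQUESTED was submitted

if [[ "$REVIEW_AT" > "$LAST_COMMIT_AT" ]]; then
  VERDICT="blocking: fresh CHANGES_REQUESTED"
else
  VERDICT="stale review: does not block"
fi
echo "$VERDICT"
```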

If it passes → mark the agent verified; per the worktree lifecycle, the state moves to `done` and the window is kept open until recycle is explicitly requested.
If it fails → re-brief the agent with the failure reason. Never manually mark the state `done` to bypass this.

### Re-brief a stalled agent

**Before sending any nudge, verify the pane is at an idle ❯ prompt.** Sending text into a still-processing pane produces a stuck `[Pasted text +N lines]` that the agent never sees.

Check:

```bash
tmux capture-pane -t SESSION:WIN -p 2>/dev/null | tail -5
```

If the last line shows a spinner (✳✽✢✶·), `Running…`, or no `❯` — wait 10–15s and check again before sending.

```bash
OBJ=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .objective' ~/.claude/orchestrator-state.json)
PR=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .pr_number' ~/.claude/orchestrator-state.json)
tmux send-keys -t SESSION:WIN "You appear stalled. Your objective: $OBJ. Check: gh pr view $PR --json title,body,headRefName to reorient."
sleep 0.3
tmux send-keys -t SESSION:WIN Enter
```

If `image_path` is set on the agent record, include: "Re-read context at IMAGE_PATH with the Read tool."

## Self-recovery protocol (agents)

spawn-agent.sh automatically includes this instruction in every objective:

> If your context compacts and you lose track of what to do, run:
> `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="SESSION:WIN")'`
> and `gh pr view PR_NUMBER --json title,body,headRefName` to reorient.
> Output each completed step as `CHECKPOINT:<step-name>` on its own line.

## Passing images and screenshots to agents

`tmux send-keys` is text-only — you cannot paste a raw image into a pane. To give an agent visual context (screenshots, diagrams, mockups):

1. **Save the image to a temp file** with a stable path:

   ```bash
   # If the user drags in a screenshot or you receive a file path:
   IMAGE_PATH="/tmp/orchestrator-context-$(date +%s).png"
   cp "$USER_PROVIDED_PATH" "$IMAGE_PATH"
   ```

2. **Reference the path in the objective string**:

   ```bash
   OBJECTIVE="Implement the layout shown in /tmp/orchestrator-context-1234567890.png. Read that image first with the Read tool to understand the design."
   ```

3. The agent uses its `Read` tool to view the image at startup — Claude Code agents are multimodal and can read image files directly.

**Rule**: always use `/tmp/orchestrator-context-<timestamp>.png` as the naming convention so the supervisor knows what to look for if it needs to re-brief an agent with the same image.

---

## Orchestrator final evaluation (YOU decide, not the script)

`verify-complete.sh` is a gate — it blocks premature marking. But it cannot tell you whether the work is actually good. That is YOUR job.

When run-loop marks an agent `pending_evaluation` and you're notified, do all of these before marking done:

### 1. Run /pr-test (required, serialized, use TodoWrite to queue)

`/pr-test` is the only reliable confirmation that the objective is actually met. Run it yourself, not the agent.

**When multiple PRs reach `pending_evaluation` at the same time, use TodoWrite to queue them:**

```
- [ ] /pr-test PR #12636 — fix copilot retry logic
- [ ] /pr-test PR #12699 — builder chat panel
```

Run one at a time. Check off as you go.

```
/pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/PR_NUMBER
```

**/pr-test can be lazy** — if it gives vague output, re-run with full context:

```
/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER
Context: This PR implements <objective from state file>. Key files: <list>.
Please verify: <specific behaviors to check>.
```

Only one `/pr-test` at a time — they share ports and a DB.

### /pr-test result evaluation

**PARTIAL on any headline feature scenario is an immediate blocker.** Do not approve, do not mark done, do not let the agent output `ORCHESTRATOR:DONE`.

| `/pr-test` result | Action |
|---|---|
| All headline scenarios **PASS** | Proceed to evaluation step 2 |
| Any headline scenario **PARTIAL** | Re-brief the agent immediately — see below |
| Any headline scenario **FAIL** | Re-brief the agent immediately |

**What PARTIAL means**: the feature is only partly working. Example: the Apply button never appeared, or the AI returned no action blocks. The agent addressed part of the objective but not all of it.

**When any headline scenario is PARTIAL or FAIL:**

1. Do NOT mark the agent done or accept `ORCHESTRATOR:DONE`
2. Re-brief the agent with the specific scenario that failed and what was missing:

   ```bash
   tmux send-keys -t SESSION:WIN "PARTIAL result on /pr-test — S5 (Apply button) never appeared. The AI must output JSON action blocks for the Apply button to render. Fix this before re-running /pr-test."
   sleep 0.3
   tmux send-keys -t SESSION:WIN Enter
   ```

3. Set state back to `running`:

   ```bash
   jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "running"' \
     ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
   ```

4. Wait for a new `ORCHESTRATOR:DONE`, then re-run `/pr-test` from scratch

**Rule: only ALL-PASS qualifies for approval.** A mix of PASS + PARTIAL is a failure.

> **Why this matters**: PR #12699 was wrongly approved with S5 PARTIAL — the AI never output JSON action blocks, so the Apply button never appeared. The fix was already within the agent's reach but slipped through because PARTIAL was not treated as blocking.

### 2. Do your own evaluation

1. **Read the PR diff and objective** — does the code actually implement what was asked? Is anything obviously missing or half-done?
2. **Read the resolved threads** — were comments addressed with real fixes, or just dismissed/resolved without changes?
3. **Check CI run names** — any suspicious retries that shouldn't have passed?
4. **Check the PR description** — are the title, summary, and test plan complete?

### 3. Decide

- `/pr-test` all scenarios PASS + evaluation looks good → mark `done` in state, tell the user the PR is ready, ask if the window should be closed
- `/pr-test` any scenario PARTIAL or FAIL → re-brief the agent with the specific failing scenario, set state back to `running` (see **/pr-test result evaluation** above)
- Evaluation finds gaps even with all PASS → re-brief the agent with the specific gaps, set state back to `running`

**Never mark done based purely on script output.** You hold the full objective context; the script does not.

```bash
# Mark done after your positive evaluation:
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "done"' \
  ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```

## When to stop the fleet

Stop the fleet (`active = false`) when **all** of the following are true:

| Check | How to verify |
|---|---|
| All agents are `done` or `escalated` | `jq '[.agents[] \| select(.state \| test("running\|stuck\|idle\|waiting_approval"))] \| length' ~/.claude/orchestrator-state.json` == 0 |
| All PRs have 0 unresolved review threads | GraphQL `isResolved` check per PR |
| All PRs have green CI **on a run triggered after the agent's last push** | `gh run list --branch BRANCH --limit 1` timestamp > `spawned_at` in state |
| No fresh CHANGES_REQUESTED (after the latest commit) | `verify-complete.sh` checks this — stale pre-commit reviews are ignored |
| No agents are `escalated` without human review | If any are escalated, surface to the user first |

**Do NOT stop just because agents output `ORCHESTRATOR:DONE`.** That is a signal to verify, not a signal to stop.

**Do stop** if the user explicitly says "stop", "shut down", or "kill everything", even with agents still running.

```bash
# Graceful stop
jq '.active = false' ~/.claude/orchestrator-state.json > /tmp/orch.tmp \
  && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json

LOOP_WINDOW=$(jq -r '.loop_window // ""' ~/.claude/orchestrator-state.json)
[ -n "$LOOP_WINDOW" ] && tmux kill-window -t "$LOOP_WINDOW" 2>/dev/null || true
```

This does **not** recycle running worktrees — agents may still be mid-task. Run `capacity.sh` to see what's still in progress.

## tmux send-keys pattern

**Always split long messages into text + Enter as two separate calls with a sleep between them.** If sent as one call (`"text" Enter`), Enter can fire before the full string is buffered into Claude's input — leaving the message stuck as `[Pasted text +N lines]`, unsent.

```bash
# CORRECT — text, then Enter separately
tmux send-keys -t "$WINDOW" "your long message here"
sleep 0.3
tmux send-keys -t "$WINDOW" Enter

# WRONG — Enter may fire before the text is buffered
tmux send-keys -t "$WINDOW" "your long message here" Enter
```

Short single-character sends (`y`, `Down`, an empty Enter for dialog approval) are safe to combine since they have no buffering lag.

---

## Protected worktrees

Some worktrees must **never** be used as spare worktrees for agent tasks because they host files critical to the orchestrator itself:

| Worktree | Protected branch | Why |
|---|---|---|
| `AutoGPT1` | `dx/orchestrate-skill` | Hosts the orchestrate skill scripts. `recycle-agent.sh` would check out `spare/1`, wiping `.claude/skills/` and breaking all subsequent `spawn-agent.sh` calls. |

**Rule**: when selecting spare worktrees via `find-spare.sh`, skip any worktree whose CURRENT branch matches a protected branch. If you accidentally spawn an agent in a protected worktree, do not let `recycle-agent.sh` run on it — manually restore the branch after the agent finishes.
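A minimal sketch of that filter over `find-spare.sh`-style output (`PATH BRANCH` per line; the input here is sample data):

```bash
# Drop any worktree whose current branch is protected before assigning agents.
SPARES=$'/w/AutoGPT1 dx/orchestrate-skill\n/w/AutoGPT2 spare/2\n/w/AutoGPT3 spare/3'
PROTECTED="dx/orchestrate-skill"
USABLE=$(echo "$SPARES" | awk -v p="$PROTECTED" '$2 != p {print $1, $2}')
echo "$USABLE"
```

Only the `spare/*` worktrees survive the filter; `AutoGPT1` is never handed to an agent.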

When `dx/orchestrate-skill` is merged into `dev`, `AutoGPT1` becomes a normal spare again.

---

## Thread resolution integrity (critical)

**Agents MUST NOT resolve review threads via GraphQL unless a real code fix has been committed and pushed first.**

This is the most common failure mode: agents call `resolveReviewThread` to make unresolved counts drop without actually fixing anything. This produces a false "done" signal that gets past `verify-complete.sh`.

**The only valid resolution sequence:**

1. Read the thread and understand what it's asking
2. Make the actual code change
3. `git commit` and `git push`
4. Reply to the thread with the commit SHA (e.g. "Fixed in `abc1234`")
5. THEN call `resolveReviewThread`
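As a sketch, steps 3–5 map onto calls like these. The mutation names (`addPullRequestReviewThreadReply`, `resolveReviewThread`) come from GitHub's GraphQL API; `THREAD_ID` is a placeholder, and the fix commit is assumed to be pushed before this runs:

```bash
# resolve_thread_properly THREAD_ID — reply with the pushed commit's SHA,
# THEN resolve. Assumes the fix commit is already committed and pushed.
resolve_thread_properly() {
  local thread_id="$1"
  local sha
  sha=$(git rev-parse --short HEAD)
  gh api graphql -f query="mutation { addPullRequestReviewThreadReply(input: {pullRequestReviewThreadId: \"${thread_id}\", body: \"Fixed in ${sha}\"}) { comment { id } } }"
  gh api graphql -f query="mutation { resolveReviewThread(input: {threadId: \"${thread_id}\"}) { thread { isResolved } } }"
}
```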
**The supervisor must verify actual thread counts via GraphQL** — never trust an agent's claim of "0 unresolved." After any agent's ORCHESTRATOR:DONE, always run:

```bash
# Step 1: get total count
TOTAL=$(gh api graphql -f query='{ repository(owner: "OWNER", name: "REPO") { pullRequest(number: PR) { reviewThreads { totalCount } } } }' \
  | jq '.data.repository.pullRequest.reviewThreads.totalCount')
echo "Total threads: $TOTAL"

# Step 2: paginate all pages and count unresolved
CURSOR=""; UNRESOLVED=0
while true; do
  AFTER=${CURSOR:+", after: \"$CURSOR\""}
  PAGE=$(gh api graphql -f query="{ repository(owner: \"OWNER\", name: \"REPO\") { pullRequest(number: PR) { reviewThreads(first: 100${AFTER}) { pageInfo { hasNextPage endCursor } nodes { isResolved } } } } }")
  UNRESOLVED=$(( UNRESOLVED + $(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved==false)] | length') ))
  HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
  CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
  [ "$HAS_NEXT" = "false" ] && break
done
echo "Unresolved: $UNRESOLVED"
```

If unresolved > 0, the agent is NOT done — re-brief with the actual count and the rule.

**Include this in every agent objective:**

> IMPORTANT: Do NOT resolve any review thread via GraphQL unless the code fix is committed and pushed first. Fix the code → commit → push → reply with SHA → then resolve. Never resolve without a real commit. "Accepted" or "Acknowledged" replies are NOT resolutions — only real commits qualify.

### Detecting fake resolutions

When an agent claims "0 unresolved threads", query GitHub GraphQL yourself and also inspect how each thread was resolved. A resolved thread whose last comment is `"Acknowledged"`, `"Same as above"`, `"Accepted trade-off"`, or `"Deferred"` — with no commit SHA — is a fake resolution.

To spot these, paginate all pages and collect resolved threads with missing SHA links:

```bash
# Paginate all pages — first:100 misses threads beyond page 1 on large PRs
CURSOR=""; FAKE_RESOLUTIONS="[]"
while true; do
  AFTER=${CURSOR:+", after: \"$CURSOR\""}
  PAGE=$(gh api graphql -f query="
    {
      repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
        pullRequest(number: PR_NUMBER) {
          reviewThreads(first: 100${AFTER}) {
            pageInfo { hasNextPage endCursor }
            nodes {
              isResolved
              comments(last: 1) {
                nodes { body author { login } }
              }
            }
          }
        }
      }
    }")
  PAGE_FAKES=$(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[]
    | select(.isResolved == true)
    | {body: .comments.nodes[0].body[:120], author: .comments.nodes[0].author.login}
    | select(.body | test("Fixed in|Removed in|Addressed in") | not)]')
  FAKE_RESOLUTIONS=$(echo "$FAKE_RESOLUTIONS $PAGE_FAKES" | jq -s 'add')
  HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
  CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
  [ "$HAS_NEXT" = "false" ] && break
done
echo "$FAKE_RESOLUTIONS"
```

Any resolved thread whose last comment does NOT contain `"Fixed in"`, `"Removed in"`, or `"Addressed in"` (with a commit link) should be investigated — either the agent falsely resolved it, or it was a genuine false positive that needs explanation.

## GitHub abuse rate limits

Two distinct rate limits exist with different recovery times:

| Error | HTTP status | Cause | Recovery |
|---|---|---|---|
| `{"code":"abuse"}` in body | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **2–3 minutes**. 60s is often not enough. |
| `API rate limit exceeded` | 429 | Primary rate limit — too many read calls per hour | Wait until the `X-RateLimit-Reset` timestamp |

**Prevention:** Agents must add `sleep 3` between individual thread reply API calls. For >20 unresolved threads, increase to `sleep 5`.

If you see a 403 `abuse` error in an agent's pane:

1. Nudge the agent: `"You hit a GitHub secondary rate limit (403). Stop all API writes. Wait 2 minutes, then resume with sleep 3 between each thread reply."`
2. Do NOT nudge again during the 2-minute wait — a second nudge restarts the clock.

Add this to agent briefings when there are >20 unresolved threads:

> Post replies with `sleep 3` between each reply. If you hit a 403 abuse error, wait 2 minutes (not 60s — secondary limits take longer to clear) then continue.
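This pacing guidance can be expressed as a small wrapper (a sketch: `post_reply` is a hypothetical stand-in for whatever `gh api` call the agent actually uses; the delay and back-off values mirror the guidance above):

```bash
# paced_replies THREAD_ID... — call post_reply for each thread with a pause
# between writes; on failure (e.g. a 403 abuse error), back off 120s and
# retry once. post_reply is a hypothetical stand-in for the real API call.
paced_replies() {
  local delay="${REPLY_DELAY:-3}"
  local thread_id
  for thread_id in "$@"; do
    if ! post_reply "$thread_id"; then
      echo "possible 403 abuse limit — backing off 120s" >&2
      sleep "${ABUSE_BACKOFF:-120}"
      post_reply "$thread_id" || echo "still failing: $thread_id" >&2
    fi
    sleep "$delay"
  done
}
```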
## Key rules

1. **Scripts do all the heavy lifting** — don't reimplement their logic inline in this file
2. **Never ask the user to pick a worktree** — auto-assign from `find-spare.sh` output
3. **Never restart a running agent** — only restart on `idle` kicks (foreground is a shell)
4. **Auto-dismiss settings dialogs** — if "Enter to confirm" appears, send Down+Enter
5. **Always `--permission-mode bypassPermissions`** on every spawn
6. **Escalate after 3 kicks** — mark `escalated`, surface to user
7. **Atomic state writes** — always write to `.tmp` then `mv`
8. **Never approve destructive commands** outside the worktree scope — when in doubt, escalate
9. **Never recycle without verification** — `verify-complete.sh` must pass before recycling
10. **No TASK.md files** — commit risk; use state file + `gh pr view` for agent context persistence
11. **Re-brief stalled agents** — read objective from state file + `gh pr view`, send via tmux
12. **ORCHESTRATOR:DONE is a signal to verify, not to accept** — always run `verify-complete.sh` and check CI run timestamp before recycling
13. **Protected worktrees** — never use the worktree hosting the skill scripts as a spare
14. **Images via file path** — save screenshots to `/tmp/orchestrator-context-<ts>.png`, pass path in objective; agents read with the `Read` tool
15. **Split send-keys** — always separate text and Enter with `sleep 0.3` between calls for long strings
16. **Poll ALL windows from state file** — never hardcode window count. Derive active windows dynamically: `jq -r '.agents[] | select(.state | test("running|idle|stuck")) | .window' ~/.claude/orchestrator-state.json`. If you added a window mid-session outside spawn-agent.sh, add it to the state file immediately.
17. **Update state file on re-task** — whenever an agent is re-tasked mid-session (objective changes, new PR assigned), update the state file record immediately so objectives stay accurate for re-briefing after compaction.
18. **No GraphQL resolveReviewThread without a commit** — see Thread resolution integrity above. This is rule #1 for pr-address work.
19. **Verify thread counts yourself** — after any agent claims "0 unresolved threads", query GitHub GraphQL directly before accepting. Never trust the agent's self-report.
20. **Orchestrator handles its own approvals** — when spawning a subagent to make edits (SKILL.md, scripts, config), review the diff yourself and approve/reject without surfacing it to the user. The user should never have to open a file to check the orchestrator's work. Use the Agent tool with `subagent_type: general-purpose` for drafting, then verify the result yourself before considering the task done.

43
.claude/skills/orchestrate/scripts/capacity.sh
Executable file
@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# capacity.sh — show fleet capacity: available spare worktrees + in-use agents
#
# Usage: capacity.sh [REPO_ROOT]
#   REPO_ROOT defaults to the root worktree of the current git repo.
#
# Reads: ~/.claude/orchestrator-state.json (skipped if missing or corrupt)

set -euo pipefail

SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"

echo "=== Available (spare) worktrees ==="
if [ -n "$REPO_ROOT" ]; then
  SPARE=$("$SCRIPTS_DIR/find-spare.sh" "$REPO_ROOT" 2>/dev/null || echo "")
else
  SPARE=$("$SCRIPTS_DIR/find-spare.sh" 2>/dev/null || echo "")
fi

if [ -z "$SPARE" ]; then
  echo " (none)"
else
  while IFS= read -r line; do
    [ -z "$line" ] && continue
    echo " ✓ $line"
  done <<< "$SPARE"
fi

echo ""
echo "=== In-use worktrees ==="
if [ -f "$STATE_FILE" ] && jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
  IN_USE=$(jq -r '.agents[] | select(.state != "done") | " [\(.state)] \(.worktree_path) → \(.branch)"' \
    "$STATE_FILE" 2>/dev/null || echo "")
  if [ -n "$IN_USE" ]; then
    echo "$IN_USE"
  else
    echo " (none)"
  fi
else
  echo " (no active state file)"
fi
85
.claude/skills/orchestrate/scripts/classify-pane.sh
Executable file
@@ -0,0 +1,85 @@
#!/usr/bin/env bash
# classify-pane.sh — Classify the current state of a tmux pane
#
# Usage: classify-pane.sh <tmux-target>
#   tmux-target: e.g. "work:0", "work:1.0"
#
# Output (stdout): JSON object:
#   { "state": "running|idle|waiting_approval|complete", "reason": "...", "pane_cmd": "..." }
#
# Exit codes: 0=ok, 1=error (invalid target or tmux window not found)

set -euo pipefail

TARGET="${1:-}"

if [ -z "$TARGET" ]; then
  echo '{"state":"error","reason":"no target provided","pane_cmd":""}'
  exit 1
fi

# Validate tmux target format: session:window or session:window.pane
if ! [[ "$TARGET" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
  echo '{"state":"error","reason":"invalid tmux target format","pane_cmd":""}'
  exit 1
fi

# Check session exists (use %%:* to extract session name from session:window)
if ! tmux list-windows -t "${TARGET%%:*}" >/dev/null 2>&1; then
  echo '{"state":"error","reason":"tmux target not found","pane_cmd":""}'
  exit 1
fi

# Get the current foreground command in the pane
PANE_CMD=$(tmux display-message -t "$TARGET" -p '#{pane_current_command}' 2>/dev/null || echo "unknown")

# Capture and strip ANSI codes (use perl for cross-platform compatibility — BSD sed lacks \x1b support)
RAW=$(tmux capture-pane -t "$TARGET" -p -S -50 2>/dev/null || echo "")
CLEAN=$(echo "$RAW" | perl -pe 's/\x1b\[[0-9;]*[a-zA-Z]//g; s/\x1b\(B//g; s/\x1b\[\?[0-9]*[hl]//g; s/\r//g' \
  | grep -v '^[[:space:]]*$' || true)

# --- Check: explicit completion marker ---
# Must be on its own line (not buried in the objective text sent at spawn time).
if echo "$CLEAN" | grep -qE "^[[:space:]]*ORCHESTRATOR:DONE[[:space:]]*$"; then
  jq -n --arg cmd "$PANE_CMD" '{"state":"complete","reason":"ORCHESTRATOR:DONE marker found","pane_cmd":$cmd}'
  exit 0
fi

# --- Check: Claude Code approval prompt patterns ---
LAST_40=$(echo "$CLEAN" | tail -40)
APPROVAL_PATTERNS=(
  "Do you want to proceed"
  "Do you want to make this"
  "\\[y/n\\]"
  "\\[Y/n\\]"
  "\\[n/Y\\]"
  "Proceed\\?"
  "Allow this command"
  "Run bash command"
  "Allow bash"
  "Would you like"
  "Press enter to continue"
  "Esc to cancel"
)
for pattern in "${APPROVAL_PATTERNS[@]}"; do
  if echo "$LAST_40" | grep -qiE "$pattern"; then
    jq -n --arg pattern "$pattern" --arg cmd "$PANE_CMD" \
      '{"state":"waiting_approval","reason":"approval pattern: \($pattern)","pane_cmd":$cmd}'
    exit 0
  fi
done

# --- Check: shell prompt (claude has exited) ---
# If the foreground process is a shell (not claude/node), the agent has exited
case "$PANE_CMD" in
  zsh|bash|fish|sh|dash|tcsh|ksh)
    jq -n --arg cmd "$PANE_CMD" \
      '{"state":"idle","reason":"agent exited — shell prompt active","pane_cmd":$cmd}'
    exit 0
    ;;
esac

# Agent is still running (claude/node/python is the foreground process)
jq -n --arg cmd "$PANE_CMD" \
  '{"state":"running","reason":"foreground process: \($cmd)","pane_cmd":$cmd}'
exit 0
24
.claude/skills/orchestrate/scripts/find-spare.sh
Executable file
@@ -0,0 +1,24 @@
#!/usr/bin/env bash
# find-spare.sh — list worktrees on spare/N branches (free to use)
#
# Usage: find-spare.sh [REPO_ROOT]
#   REPO_ROOT defaults to the root worktree containing the current git repo.
#
# Output (stdout): one line per available worktree: "PATH BRANCH"
#   e.g.: /Users/me/Code/AutoGPT3 spare/3

set -euo pipefail

REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
if [ -z "$REPO_ROOT" ]; then
  echo "Error: not inside a git repo and no REPO_ROOT provided" >&2
  exit 1
fi

git -C "$REPO_ROOT" worktree list --porcelain \
  | awk '
      /^worktree / { path = substr($0, 10) }
      /^branch /   { branch = substr($0, 8); print path " " branch }
    ' \
  | { grep -E " refs/heads/spare/[0-9]+$" || true; } \
  | sed 's|refs/heads/||'
40
.claude/skills/orchestrate/scripts/notify.sh
Executable file
@@ -0,0 +1,40 @@
#!/usr/bin/env bash
# notify.sh — send a fleet notification message
#
# Delivery order (first available wins):
#   1. Discord webhook — DISCORD_WEBHOOK_URL env var OR state file .discord_webhook
#   2. macOS notification center — osascript (silent fail if unavailable)
#   3. Stdout only
#
# Usage: notify.sh MESSAGE
# Exit: always 0 (notification failure must not abort the caller)

MESSAGE="${1:-}"
[ -z "$MESSAGE" ] && exit 0

STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"

# --- Resolve Discord webhook ---
WEBHOOK="${DISCORD_WEBHOOK_URL:-}"
if [ -z "$WEBHOOK" ] && [ -f "$STATE_FILE" ]; then
  WEBHOOK=$(jq -r '.discord_webhook // ""' "$STATE_FILE" 2>/dev/null || echo "")
fi

# --- Discord delivery ---
if [ -n "$WEBHOOK" ]; then
  PAYLOAD=$(jq -n --arg msg "$MESSAGE" '{"content": $msg}')
  curl -s -X POST "$WEBHOOK" \
    -H "Content-Type: application/json" \
    -d "$PAYLOAD" > /dev/null 2>&1 || true
fi

# --- macOS notification center (silent if not macOS or osascript missing) ---
if command -v osascript >/dev/null 2>&1; then
  # Escape single quotes for AppleScript
  SAFE_MSG=$(echo "$MESSAGE" | sed "s/'/\\\\'/g")
  osascript -e "display notification \"${SAFE_MSG}\" with title \"Orchestrator\"" 2>/dev/null || true
fi

# Always print to stdout so run-loop.sh logs it
echo "$MESSAGE"
exit 0
257
.claude/skills/orchestrate/scripts/poll-cycle.sh
Executable file
@@ -0,0 +1,257 @@
#!/usr/bin/env bash
# poll-cycle.sh — Single orchestrator poll cycle
#
# Reads ~/.claude/orchestrator-state.json, classifies each agent, updates state,
# and outputs a JSON array of actions for Claude to take.
#
# Usage: poll-cycle.sh
# Output (stdout): JSON array of action objects
#   [{ "window": "work:0", "action": "kick|approve|none", "state": "...",
#      "worktree": "...", "objective": "...", "reason": "..." }]
#
# The state file is updated in-place (atomic write via .tmp).

set -euo pipefail

STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
CLASSIFY="$SCRIPTS_DIR/classify-pane.sh"

# Cross-platform md5: always outputs just the hex digest
md5_hash() {
  if command -v md5sum &>/dev/null; then
    md5sum | awk '{print $1}'
  else
    md5 | awk '{print $NF}'
  fi
}

# Clean up temp file on any exit (avoids stale .tmp if jq write fails)
trap 'rm -f "${STATE_FILE}.tmp"' EXIT

# Ensure state file exists
if [ ! -f "$STATE_FILE" ]; then
  echo '{"active":false,"agents":[]}' > "$STATE_FILE"
fi

# Validate JSON upfront before any jq reads that run under set -e.
# A truncated/corrupt file (e.g. from a SIGKILL mid-write) would otherwise
# abort the script at the ACTIVE read below without emitting any JSON output.
if ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
  echo "State file parse error — check $STATE_FILE" >&2
  echo "[]"
  exit 0
fi

ACTIVE=$(jq -r '.active // false' "$STATE_FILE")
if [ "$ACTIVE" != "true" ]; then
  echo "[]"
  exit 0
fi

NOW=$(date +%s)
IDLE_THRESHOLD=$(jq -r '.idle_threshold_seconds // 300' "$STATE_FILE")

ACTIONS="[]"
UPDATED_AGENTS="[]"

# Read agents as newline-delimited JSON objects.
# jq exits non-zero when .agents[] has no matches on an empty array, which is valid —
# so we suppress that exit code and separately validate the file is well-formed JSON.
if ! AGENTS_JSON=$(jq -e -c '.agents // empty | .[]' "$STATE_FILE" 2>/dev/null); then
  if ! jq -e '.' "$STATE_FILE" > /dev/null 2>&1; then
    echo "State file parse error — check $STATE_FILE" >&2
  fi
  echo "[]"
  exit 0
fi

if [ -z "$AGENTS_JSON" ]; then
  echo "[]"
  exit 0
fi

while IFS= read -r agent; do
  [ -z "$agent" ] && continue

  # Use // "" defaults so a single malformed field doesn't abort the whole cycle
  WINDOW=$(echo "$agent" | jq -r '.window // ""')
  WORKTREE=$(echo "$agent" | jq -r '.worktree // ""')
  OBJECTIVE=$(echo "$agent" | jq -r '.objective // ""')
  STATE=$(echo "$agent" | jq -r '.state // "running"')
  LAST_HASH=$(echo "$agent" | jq -r '.last_output_hash // ""')
  IDLE_SINCE=$(echo "$agent" | jq -r '.idle_since // 0')
  REVISION_COUNT=$(echo "$agent" | jq -r '.revision_count // 0')

  # Validate window format to prevent tmux target injection.
  # Allow session:window (numeric or named) and session:window.pane
  if ! [[ "$WINDOW" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
    echo "Skipping agent with invalid window value: $WINDOW" >&2
    UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
    continue
  fi

  # Pass-through terminal-state agents
  if [[ "$STATE" == "done" || "$STATE" == "escalated" || "$STATE" == "complete" || "$STATE" == "pending_evaluation" ]]; then
    UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
    continue
  fi

  # Classify pane.
  # classify-pane.sh always emits JSON before exit (even on error), so using
  # "|| echo '...'" would concatenate two JSON objects when it exits non-zero.
  # Use "|| true" inside the substitution so set -euo pipefail does not abort
  # the poll cycle when classify exits with a non-zero status code.
  CLASSIFICATION=$("$CLASSIFY" "$WINDOW" 2>/dev/null || true)
  [ -z "$CLASSIFICATION" ] && CLASSIFICATION='{"state":"error","reason":"classify failed","pane_cmd":"unknown"}'

  PANE_STATE=$(echo "$CLASSIFICATION" | jq -r '.state')
  PANE_REASON=$(echo "$CLASSIFICATION" | jq -r '.reason')

  # Capture full pane output once — used for hash (stuck detection) and checkpoint parsing.
  # Use -S -500 to get the last ~500 lines of scrollback so checkpoints aren't missed.
  RAW=$(tmux capture-pane -t "$WINDOW" -p -S -500 2>/dev/null || echo "")

  # --- Checkpoint tracking ---
  # Parse any "CHECKPOINT:<step>" lines the agent has output and merge into state file.
  # The agent writes these as it completes each required step so verify-complete.sh can gate recycling.
  EXISTING_CPS=$(echo "$agent" | jq -c '.checkpoints // []')
  NEW_CHECKPOINTS_JSON="$EXISTING_CPS"
  if [ -n "$RAW" ]; then
    FOUND_CPS=$(echo "$RAW" \
      | grep -oE "CHECKPOINT:[a-zA-Z0-9_-]+" \
      | sed 's/CHECKPOINT://' \
      | sort -u \
      | jq -R . | jq -s . 2>/dev/null || echo "[]")
    NEW_CHECKPOINTS_JSON=$(jq -n \
      --argjson existing "$EXISTING_CPS" \
      --argjson found "$FOUND_CPS" \
      '($existing + $found) | unique' 2>/dev/null || echo "$EXISTING_CPS")
  fi

  # Compute content hash for stuck-detection (only for running agents)
  CURRENT_HASH=""
  if [[ "$PANE_STATE" == "running" ]] && [ -n "$RAW" ]; then
    CURRENT_HASH=$(echo "$RAW" | tail -20 | md5_hash)
  fi

  NEW_STATE="$STATE"
  NEW_IDLE_SINCE="$IDLE_SINCE"
  NEW_REVISION_COUNT="$REVISION_COUNT"
  ACTION="none"
  REASON="$PANE_REASON"

  case "$PANE_STATE" in
    complete)
      # Agent output ORCHESTRATOR:DONE — mark pending_evaluation so orchestrator handles it.
      # run-loop does NOT verify or notify; orchestrator's background poll picks this up.
      NEW_STATE="pending_evaluation"
      ACTION="complete"  # run-loop logs it but takes no action
      ;;
    waiting_approval)
      NEW_STATE="waiting_approval"
      ACTION="approve"
      ;;
    idle)
      # Agent process has exited — needs restart
      NEW_STATE="idle"
      ACTION="kick"
      REASON="agent exited (shell is foreground)"
      NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
      NEW_IDLE_SINCE=$NOW
      if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
        NEW_STATE="escalated"
        ACTION="none"
        REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
      fi
      ;;
    running)
      # Clear idle_since only when transitioning from idle (agent was kicked and
      # restarted). Do NOT reset for stuck — idle_since must persist across polls
      # so STUCK_DURATION can accumulate and trigger escalation.
      # Also update the local IDLE_SINCE so the hash-stability check below uses
      # the reset value on this same poll, not the stale kick timestamp.
      if [[ "$STATE" == "idle" ]]; then
        NEW_IDLE_SINCE=0
        IDLE_SINCE=0
      fi
      # Check if hash has been stable (agent may be stuck mid-task)
      if [ -n "$CURRENT_HASH" ] && [ "$CURRENT_HASH" = "$LAST_HASH" ] && [ "$LAST_HASH" != "" ]; then
        if [ "$IDLE_SINCE" = "0" ] || [ "$IDLE_SINCE" = "null" ]; then
          NEW_IDLE_SINCE=$NOW
        else
          STUCK_DURATION=$(( NOW - IDLE_SINCE ))
          if [ "$STUCK_DURATION" -gt "$IDLE_THRESHOLD" ]; then
            NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
            NEW_IDLE_SINCE=$NOW
            if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
              NEW_STATE="escalated"
              ACTION="none"
              REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
            else
              NEW_STATE="stuck"
              ACTION="kick"
              REASON="output unchanged for ${STUCK_DURATION}s (threshold: ${IDLE_THRESHOLD}s)"
            fi
          fi
        fi
      else
        # Only reset the idle timer when we have a valid hash comparison (pane
        # capture succeeded). If CURRENT_HASH is empty (tmux capture-pane failed),
        # preserve existing timers so stuck detection is not inadvertently reset.
        if [ -n "$CURRENT_HASH" ]; then
          NEW_STATE="running"
          NEW_IDLE_SINCE=0
        fi
      fi
      ;;
    error)
      REASON="classify error: $PANE_REASON"
      ;;
  esac

  # Build updated agent record (ensure idle_since and revision_count are numeric).
  # Fall back to keeping the original record if jq fails, so a malformed field
  # skips this agent rather than aborting the entire poll cycle under set -e.
  UPDATED_AGENT=$(echo "$agent" | jq \
    --arg state "$NEW_STATE" \
    --arg hash "$CURRENT_HASH" \
    --argjson now "$NOW" \
    --arg idle_since "$NEW_IDLE_SINCE" \
    --arg revision_count "$NEW_REVISION_COUNT" \
    --argjson checkpoints "$NEW_CHECKPOINTS_JSON" \
    '.state = $state
     | .last_output_hash = (if $hash == "" then .last_output_hash else $hash end)
     | .last_seen_at = $now
     | .idle_since = ($idle_since | tonumber)
     | .revision_count = ($revision_count | tonumber)
     | .checkpoints = $checkpoints' 2>/dev/null) || {
    echo "Warning: failed to build updated agent for window $WINDOW — keeping original" >&2
    UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
    continue
  }

  UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$UPDATED_AGENT" '. + [$a]')

  # Add action if needed
  if [ "$ACTION" != "none" ]; then
    ACTION_OBJ=$(jq -n \
      --arg window "$WINDOW" \
      --arg action "$ACTION" \
      --arg state "$NEW_STATE" \
      --arg worktree "$WORKTREE" \
      --arg objective "$OBJECTIVE" \
      --arg reason "$REASON" \
      '{window:$window, action:$action, state:$state, worktree:$worktree, objective:$objective, reason:$reason}')
    ACTIONS=$(echo "$ACTIONS" | jq --argjson a "$ACTION_OBJ" '. + [$a]')
  fi

done <<< "$AGENTS_JSON"

# Atomic state file update
jq --argjson agents "$UPDATED_AGENTS" \
   --argjson now "$NOW" \
   '.agents = $agents | .last_poll_at = $now' \
   "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"

echo "$ACTIONS"
32
.claude/skills/orchestrate/scripts/recycle-agent.sh
Executable file
@@ -0,0 +1,32 @@
#!/usr/bin/env bash
# recycle-agent.sh — kill a tmux window and restore the worktree to its spare branch
#
# Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH
#   WINDOW        — tmux target, e.g. autogpt1:3
#   WORKTREE_PATH — absolute path to the git worktree
#   SPARE_BRANCH  — branch to restore, e.g. spare/6
#
# Stdout: one status line

set -euo pipefail

if [ $# -lt 3 ]; then
  echo "Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH" >&2
  exit 1
fi

WINDOW="$1"
WORKTREE_PATH="$2"
SPARE_BRANCH="$3"

# Kill the tmux window (ignore error — may already be gone)
tmux kill-window -t "$WINDOW" 2>/dev/null || true

# Restore to spare branch: abort any in-progress operation, then clean
git -C "$WORKTREE_PATH" rebase --abort 2>/dev/null || true
git -C "$WORKTREE_PATH" merge --abort 2>/dev/null || true
git -C "$WORKTREE_PATH" reset --hard HEAD 2>/dev/null
git -C "$WORKTREE_PATH" clean -fd 2>/dev/null
git -C "$WORKTREE_PATH" checkout "$SPARE_BRANCH"

echo "Recycled: $(basename "$WORKTREE_PATH") → $SPARE_BRANCH (window $WINDOW closed)"
215
.claude/skills/orchestrate/scripts/run-loop.sh
Executable file
@@ -0,0 +1,215 @@
#!/usr/bin/env bash
# run-loop.sh — Mechanical babysitter for the agent fleet (runs in its own tmux window)
#
# Handles ONLY two things that need no intelligence:
#   idle    → restart claude using --resume SESSION_ID (or --continue) to restore context
#   approve → auto-approve safe dialogs, press Enter on numbered-option dialogs
#
# Everything else — ORCHESTRATOR:DONE, verification, /pr-test, final evaluation,
# marking done, deciding to close windows — is the orchestrating Claude's job.
# poll-cycle.sh sets state to pending_evaluation when ORCHESTRATOR:DONE is detected;
# the orchestrator's background poll loop handles it from there.
#
# Usage: run-loop.sh
# Env: POLL_INTERVAL (default: 30), ORCHESTRATOR_STATE_FILE

set -euo pipefail

# Copy scripts to a stable location outside the repo so they survive branch
# checkouts (e.g. recycle-agent.sh switching spare/N back into this worktree
# would wipe .claude/skills/orchestrate/scripts if the skill only exists on the
# current branch).
_ORIGIN_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
STABLE_SCRIPTS_DIR="$HOME/.claude/orchestrator/scripts"
mkdir -p "$STABLE_SCRIPTS_DIR"
cp "$_ORIGIN_DIR"/*.sh "$STABLE_SCRIPTS_DIR/"
chmod +x "$STABLE_SCRIPTS_DIR"/*.sh
SCRIPTS_DIR="$STABLE_SCRIPTS_DIR"

STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
# Adaptive polling: starts at base interval, backs off up to POLL_IDLE_MAX when
# no agents need attention, resets on any activity or waiting_approval state.
POLL_INTERVAL="${POLL_INTERVAL:-30}"
POLL_IDLE_MAX=${POLL_IDLE_MAX:-300}
POLL_CURRENT=$POLL_INTERVAL

# ---------------------------------------------------------------------------
# update_state WINDOW FIELD VALUE
# ---------------------------------------------------------------------------
update_state() {
  local window="$1" field="$2" value="$3"
  jq --arg w "$window" --arg f "$field" --arg v "$value" \
    '.agents |= map(if .window == $w then .[$f] = $v else . end)' \
    "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
}

update_state_int() {
  local window="$1" field="$2" value="$3"
  jq --arg w "$window" --arg f "$field" --argjson v "$value" \
    '.agents |= map(if .window == $w then .[$f] = $v else . end)' \
    "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
}

agent_field() {
  jq -r --arg w "$1" --arg f "$2" \
    '.agents[] | select(.window == $w) | .[$f] // ""' \
    "$STATE_FILE" 2>/dev/null
}

# ---------------------------------------------------------------------------
# wait_for_prompt WINDOW — wait up to 60s for Claude's ❯ prompt
# ---------------------------------------------------------------------------
wait_for_prompt() {
  local window="$1"
  for i in $(seq 1 60); do
    local cmd pane
    cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
    pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
    if echo "$pane" | grep -q "Enter to confirm"; then
      tmux send-keys -t "$window" Down Enter; sleep 2; continue
    fi
    [[ "$cmd" == "node" ]] && echo "$pane" | grep -q "❯" && return 0
    sleep 1
  done
  return 1  # timed out
}

# ---------------------------------------------------------------------------
# wait_for_claude_idle WINDOW — wait up to 30s for Claude to reach idle ❯ prompt
# (no spinner or busy indicator visible in the last 3 lines of pane output)
# Returns 0 when idle, 1 on timeout.
# ---------------------------------------------------------------------------
wait_for_claude_idle() {
  local window="$1"
  local timeout="${2:-30}"
  local elapsed=0
  while (( elapsed < timeout )); do
    local cmd pane pane_tail
    cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
    pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
    pane_tail=$(echo "$pane" | tail -3)
    # Check full pane (not just tail) — 'Enter to confirm' dialog can scroll above last 3 lines.
    # Do NOT reset elapsed — resetting allows an infinite loop if the dialog never clears.
    if echo "$pane" | grep -q "Enter to confirm"; then
      tmux send-keys -t "$window" Down Enter
      sleep 2; (( elapsed += 2 )); continue
    fi
    # Must be running under node (Claude is live)
|
||||
if [[ "$cmd" == "node" ]]; then
|
||||
# Idle: ❯ prompt visible AND no spinner/busy text in last 3 lines
|
||||
if echo "$pane_tail" | grep -q "❯" && \
|
||||
! echo "$pane_tail" | grep -qE '[✳✽✢✶·✻✼✿❋✤]|Running…|Compacting'; then
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
sleep 2
|
||||
(( elapsed += 2 ))
|
||||
done
|
||||
return 1 # timed out
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_kick WINDOW STATE — only for idle (crashed) agents, not stuck
|
||||
# ---------------------------------------------------------------------------
|
||||
handle_kick() {
|
||||
local window="$1" state="$2"
|
||||
[[ "$state" != "idle" ]] && return # stuck agents handled by supervisor
|
||||
|
||||
local worktree_path session_id
|
||||
worktree_path=$(agent_field "$window" "worktree_path")
|
||||
session_id=$(agent_field "$window" "session_id")
|
||||
|
||||
echo "[$(date +%H:%M:%S)] KICK restart $window — agent exited, resuming session"
|
||||
|
||||
# Wait for the shell prompt before typing — avoids sending into a still-draining pane
|
||||
wait_for_claude_idle "$window" 30 \
|
||||
|| echo "[$(date +%H:%M:%S)] KICK WARNING $window — pane still busy before resume, sending anyway"
|
||||
|
||||
# Resume the exact session so the agent retains full context — no need to re-send objective
|
||||
if [ -n "$session_id" ]; then
|
||||
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --resume '${session_id}' --permission-mode bypassPermissions" Enter
|
||||
else
|
||||
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --continue --permission-mode bypassPermissions" Enter
|
||||
fi
|
||||
|
||||
wait_for_prompt "$window" || echo "[$(date +%H:%M:%S)] KICK WARNING $window — timed out waiting for ❯"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_approve WINDOW — auto-approve dialogs that need no judgment
|
||||
# ---------------------------------------------------------------------------
|
||||
handle_approve() {
|
||||
local window="$1"
|
||||
local pane_tail
|
||||
pane_tail=$(tmux capture-pane -t "$window" -p 2>/dev/null | tail -3 || echo "")
|
||||
|
||||
# Settings error dialog at startup
|
||||
if echo "$pane_tail" | grep -q "Enter to confirm"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE dialog $window — settings error"
|
||||
tmux send-keys -t "$window" Down Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# Numbered-option dialog (e.g. "Do you want to make this edit?")
|
||||
# ❯ is already on option 1 (Yes) — Enter confirms it
|
||||
if echo "$pane_tail" | grep -qE "❯\s*1\." || echo "$pane_tail" | grep -q "Esc to cancel"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE edit $window"
|
||||
tmux send-keys -t "$window" "" Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# y/n prompt for safe operations
|
||||
if echo "$pane_tail" | grep -qiE "(^git |^npm |^pnpm |^poetry |^pytest|^docker |^make |^cargo |^pip |^yarn |curl .*(localhost|127\.0\.0\.1))"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE safe $window"
|
||||
tmux send-keys -t "$window" "y" Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# Anything else — supervisor handles it, just log
|
||||
echo "[$(date +%H:%M:%S)] APPROVE skip $window — unknown dialog, supervisor will handle"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main loop
|
||||
# ---------------------------------------------------------------------------
|
||||
echo "[$(date +%H:%M:%S)] run-loop started (mechanical only, poll ${POLL_INTERVAL}s→${POLL_IDLE_MAX}s adaptive)"
|
||||
echo "[$(date +%H:%M:%S)] Supervisor: orchestrating Claude session (not a separate window)"
|
||||
echo "---"
|
||||
|
||||
while true; do
|
||||
if ! jq -e '.active == true' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
echo "[$(date +%H:%M:%S)] active=false — exiting."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
ACTIONS=$("$SCRIPTS_DIR/poll-cycle.sh" 2>/dev/null || echo "[]")
|
||||
KICKED=0; DONE=0
|
||||
|
||||
while IFS= read -r action; do
|
||||
[ -z "$action" ] && continue
|
||||
WINDOW=$(echo "$action" | jq -r '.window // ""')
|
||||
ACTION=$(echo "$action" | jq -r '.action // ""')
|
||||
STATE=$(echo "$action" | jq -r '.state // ""')
|
||||
|
||||
case "$ACTION" in
|
||||
kick) handle_kick "$WINDOW" "$STATE" || true; KICKED=$(( KICKED + 1 )) ;;
|
||||
approve) handle_approve "$WINDOW" || true ;;
|
||||
complete) DONE=$(( DONE + 1 )) ;; # poll-cycle already set state=pending_evaluation; orchestrator handles
|
||||
esac
|
||||
done < <(echo "$ACTIONS" | jq -c '.[]' 2>/dev/null || true)
|
||||
|
||||
RUNNING=$(jq '[.agents[] | select(.state | test("running|stuck|waiting_approval|idle"))] | length' \
|
||||
"$STATE_FILE" 2>/dev/null || echo 0)
|
||||
|
||||
# Adaptive backoff: reset to base on activity or waiting_approval agents; back off when truly idle
|
||||
WAITING=$(jq '[.agents[] | select(.state == "waiting_approval")] | length' "$STATE_FILE" 2>/dev/null || echo 0)
|
||||
if (( KICKED > 0 || DONE > 0 || WAITING > 0 )); then
|
||||
POLL_CURRENT=$POLL_INTERVAL
|
||||
else
|
||||
POLL_CURRENT=$(( POLL_CURRENT + POLL_CURRENT / 2 + 1 ))
|
||||
(( POLL_CURRENT > POLL_IDLE_MAX )) && POLL_CURRENT=$POLL_IDLE_MAX
|
||||
fi
|
||||
|
||||
echo "[$(date +%H:%M:%S)] Poll — ${RUNNING} running ${KICKED} kicked ${DONE} recycled (next in ${POLL_CURRENT}s)"
|
||||
sleep "$POLL_CURRENT"
|
||||
done
|
||||
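The backoff arithmetic in the loop above (grow by half plus one each idle cycle, capped at the ceiling) can be checked in isolation. A standalone sketch using the default 30 s base and 300 s cap; the sequence is computed by the rule itself, not taken from the script:

```shell
# Simulate run-loop.sh's adaptive poll backoff for six idle cycles.
POLL_INTERVAL=30
POLL_IDLE_MAX=300
POLL_CURRENT=$POLL_INTERVAL
SEQUENCE=""
for _ in 1 2 3 4 5 6; do
  # Same growth rule as the main loop: x -> x + x/2 + 1, capped at POLL_IDLE_MAX
  POLL_CURRENT=$(( POLL_CURRENT + POLL_CURRENT / 2 + 1 ))
  if (( POLL_CURRENT > POLL_IDLE_MAX )); then
    POLL_CURRENT=$POLL_IDLE_MAX
  fi
  SEQUENCE="${SEQUENCE}${SEQUENCE:+ }${POLL_CURRENT}"
done
echo "$SEQUENCE"  # 46 70 106 160 241 300
```

So a fully idle fleet is polled roughly every 5 minutes instead of every 30 seconds, and one kick or approval snaps the interval back to the base.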
129
.claude/skills/orchestrate/scripts/spawn-agent.sh
Executable file
@@ -0,0 +1,129 @@
#!/usr/bin/env bash
# spawn-agent.sh — create tmux window, checkout branch, launch claude, send task
#
# Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]
#   SESSION       — tmux session name, e.g. autogpt1
#   WORKTREE_PATH — absolute path to the git worktree
#   SPARE_BRANCH  — spare branch being replaced, e.g. spare/6 (saved for recycle)
#   NEW_BRANCH    — task branch to create, e.g. feat/my-feature
#   OBJECTIVE     — task description sent to the agent
#   PR_NUMBER     — (optional) GitHub PR number for completion verification
#   STEPS...      — (optional) required checkpoint names, e.g. pr-address pr-test
#
# Stdout: SESSION:WINDOW_INDEX (nothing else — callers rely on this)
# Exit non-zero on failure.

set -euo pipefail

if [ $# -lt 5 ]; then
  echo "Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]" >&2
  exit 1
fi

SESSION="$1"
WORKTREE_PATH="$2"
SPARE_BRANCH="$3"
NEW_BRANCH="$4"
OBJECTIVE="$5"
PR_NUMBER="${6:-}"
STEPS=("${@:7}")
WORKTREE_NAME=$(basename "$WORKTREE_PATH")
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"

# Generate a stable session ID so this agent's Claude session can always be resumed:
#   claude --resume $SESSION_ID --permission-mode bypassPermissions
SESSION_ID=$(uuidgen 2>/dev/null || python3 -c "import uuid; print(uuid.uuid4())")

# Create (or switch to) the task branch
git -C "$WORKTREE_PATH" checkout -b "$NEW_BRANCH" 2>/dev/null \
  || git -C "$WORKTREE_PATH" checkout "$NEW_BRANCH"

# Open a new named tmux window; capture its numeric index
WIN_IDX=$(tmux new-window -t "$SESSION" -n "$WORKTREE_NAME" -P -F '#{window_index}')
WINDOW="${SESSION}:${WIN_IDX}"

# Append the initial agent record to the state file so subsequent jq updates find it.
# This must happen before the pr_number/steps update below.
if [ -f "$STATE_FILE" ]; then
  NOW=$(date +%s)
  jq --arg window "$WINDOW" \
     --arg worktree "$WORKTREE_NAME" \
     --arg worktree_path "$WORKTREE_PATH" \
     --arg spare_branch "$SPARE_BRANCH" \
     --arg branch "$NEW_BRANCH" \
     --arg objective "$OBJECTIVE" \
     --arg session_id "$SESSION_ID" \
     --argjson now "$NOW" \
     '.agents += [{
       "window": $window,
       "worktree": $worktree,
       "worktree_path": $worktree_path,
       "spare_branch": $spare_branch,
       "branch": $branch,
       "objective": $objective,
       "session_id": $session_id,
       "state": "running",
       "checkpoints": [],
       "last_output_hash": "",
       "last_seen_at": $now,
       "spawned_at": $now,
       "idle_since": 0,
       "revision_count": 0,
       "last_rebriefed_at": 0
     }]' "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
fi

# Store pr_number + steps in state file if provided (enables verify-complete.sh).
# The agent record was appended above so the jq select now finds it.
if [ -n "$PR_NUMBER" ] && [ -f "$STATE_FILE" ]; then
  if [ "${#STEPS[@]}" -gt 0 ]; then
    STEPS_JSON=$(printf '%s\n' "${STEPS[@]}" | jq -R . | jq -s .)
  else
    STEPS_JSON='[]'
  fi
  jq --arg w "$WINDOW" --arg pr "$PR_NUMBER" --argjson steps "$STEPS_JSON" \
    '.agents |= map(if .window == $w then . + {pr_number: $pr, steps: $steps, checkpoints: []} else . end)' \
    "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
fi

# Launch claude with a stable session ID so it can always be resumed after a crash:
#   claude --resume SESSION_ID --permission-mode bypassPermissions
tmux send-keys -t "$WINDOW" "cd '${WORKTREE_PATH}' && claude --permission-mode bypassPermissions --session-id '${SESSION_ID}'" Enter

# _wait_idle — poll until the pane shows the idle ❯ prompt with no spinner in the
# last 3 lines. Returns 0 when idle, 1 on timeout.
_wait_idle() {
  local window="$1" timeout="${2:-60}" elapsed=0
  while (( elapsed < timeout )); do
    local cmd pane pane_tail
    cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
    pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
    pane_tail=$(echo "$pane" | tail -3)
    # Check full pane (not just tail) — 'Enter to confirm' dialog can appear above the last 3 lines
    if echo "$pane" | grep -q "Enter to confirm"; then
      tmux send-keys -t "$window" Down Enter
      sleep 2; (( elapsed += 2 )); continue
    fi
    if [[ "$cmd" == "node" ]] && \
       echo "$pane_tail" | grep -q "❯" && \
       ! echo "$pane_tail" | grep -qE '[✳✽✢✶·✻✼✿❋✤]|Running…|Compacting'; then
      return 0
    fi
    sleep 2; (( elapsed += 2 ))
  done
  return 1
}

# Wait up to 60s for claude to be fully interactive and idle (❯ visible, no spinner).
if ! _wait_idle "$WINDOW" 60; then
  echo "[spawn-agent] WARNING: timed out waiting for idle ❯ prompt on $WINDOW — sending objective anyway" >&2
fi

# Send the task. Split text and Enter — if combined, Enter can fire before the string
# is fully buffered, leaving the message stuck as "[Pasted text +N lines]" unsent.
tmux send-keys -t "$WINDOW" "${OBJECTIVE} Output each completed step as CHECKPOINT:<step-name>. When ALL steps are done, output ORCHESTRATOR:DONE on its own line."
sleep 0.3
tmux send-keys -t "$WINDOW" Enter

# Only output the window address — nothing else (callers parse this)
echo "$WINDOW"
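The `jq -R . | jq -s .` idiom spawn-agent.sh uses to turn positional step names into a JSON array can be tried standalone. The step names here are illustrative values, not required names:

```shell
# Convert a bash array of step names into a JSON array, as spawn-agent.sh does:
# jq -R wraps each input line as a JSON string, jq -s collects them into an array.
STEPS=(pr-address pr-test)
STEPS_JSON=$(printf '%s\n' "${STEPS[@]}" | jq -R . | jq -s .)
echo "$STEPS_JSON" | jq -c .  # ["pr-address","pr-test"]
```

This round-trips safely even if a step name contains characters that would break naive string interpolation into JSON.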
43
.claude/skills/orchestrate/scripts/status.sh
Executable file
@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# status.sh — print orchestrator status: state file summary + live tmux pane commands
#
# Usage: status.sh
# Reads: ~/.claude/orchestrator-state.json

set -euo pipefail

STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"

if [ ! -f "$STATE_FILE" ] || ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
  echo "No orchestrator state found at $STATE_FILE"
  exit 0
fi

# Header: active status, session, thresholds, last poll
jq -r '
  "=== Orchestrator [\(if .active then "RUNNING" else "STOPPED" end)] ===",
  "Session: \(.tmux_session // "unknown") | Idle threshold: \(.idle_threshold_seconds // 300)s",
  "Last poll: \(if (.last_poll_at // 0) == 0 then "never" else (.last_poll_at | strftime("%H:%M:%S")) end)",
  ""
' "$STATE_FILE"

# Each agent: state, window, worktree/branch, truncated objective
AGENT_COUNT=$(jq '.agents | length' "$STATE_FILE")
if [ "$AGENT_COUNT" -eq 0 ]; then
  echo "  (no agents registered)"
else
  jq -r '
    .agents[] |
    "  [\(.state | ascii_upcase)] \(.window)  \(.worktree)/\(.branch)",
    "      \(.objective // "" | .[0:70])"
  ' "$STATE_FILE"
fi

echo ""

# Live pane_current_command for non-done agents
while IFS= read -r WINDOW; do
  [ -z "$WINDOW" ] && continue
  CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "unreachable")
  echo "  $WINDOW live: $CMD"
done < <(jq -r '.agents[] | select(.state != "done") | .window' "$STATE_FILE" 2>/dev/null || true)
180
.claude/skills/orchestrate/scripts/verify-complete.sh
Normal file
@@ -0,0 +1,180 @@
#!/usr/bin/env bash
# verify-complete.sh — verify a PR task is truly done before marking the agent done
#
# Check order matters:
#   1. Checkpoints — did the agent do all required steps?
#   2. CI complete — no pending (bots post comments AFTER their check runs, must wait)
#   3. CI passing — no failures (agent must fix before done)
#   4. spawned_at — a new CI run was triggered after the agent spawned (proves real work)
#   5. Unresolved threads — checked AFTER CI so bot-posted comments are included
#   6. CHANGES_REQUESTED — checked AFTER CI so bot reviews are included
#
# Usage: verify-complete.sh WINDOW
# Exit 0 = verified complete; exit 1 = not complete (stderr has reason)

set -euo pipefail

WINDOW="$1"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"

PR_NUMBER=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .pr_number // ""' "$STATE_FILE" 2>/dev/null)
STEPS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .steps // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
CHECKPOINTS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .checkpoints // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
WORKTREE_PATH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .worktree_path // ""' "$STATE_FILE" 2>/dev/null)
BRANCH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .branch // ""' "$STATE_FILE" 2>/dev/null)
SPAWNED_AT=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .spawned_at // "0"' "$STATE_FILE" 2>/dev/null || echo "0")

# No PR number = cannot verify
if [ -z "$PR_NUMBER" ]; then
  echo "NOT COMPLETE: no pr_number in state — set pr_number or mark done manually" >&2
  exit 1
fi

# --- Check 1: all required steps are checkpointed ---
MISSING=""
while IFS= read -r step; do
  [ -z "$step" ] && continue
  if ! echo "$CHECKPOINTS" | grep -qFx "$step"; then
    MISSING="$MISSING $step"
  fi
done <<< "$STEPS"

if [ -n "$MISSING" ]; then
  echo "NOT COMPLETE: missing checkpoints:$MISSING on PR #$PR_NUMBER" >&2
  exit 1
fi

# Resolve repo for all GitHub checks below
REPO=$(jq -r '.repo // ""' "$STATE_FILE" 2>/dev/null || echo "")
if [ -z "$REPO" ] && [ -n "$WORKTREE_PATH" ] && [ -d "$WORKTREE_PATH" ]; then
  REPO=$(git -C "$WORKTREE_PATH" remote get-url origin 2>/dev/null \
    | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
fi

if [ -z "$REPO" ]; then
  echo "Warning: cannot resolve repo — skipping CI/thread checks" >&2
  echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓ (CI/thread checks skipped — no repo)"
  exit 0
fi

CI_BUCKETS=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket 2>/dev/null || echo "[]")

# --- Check 2: CI fully complete — no pending checks ---
# Pending checks MUST finish before we check threads/reviews:
# bots (Seer, Check PR Status, etc.) post comments and CHANGES_REQUESTED AFTER their CI check runs.
PENDING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "pending")] | length' 2>/dev/null || echo "0")
if [ "$PENDING" -gt 0 ]; then
  PENDING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
    | jq -r '[.[] | select(.bucket == "pending") | .name] | join(", ")' 2>/dev/null || echo "unknown")
  echo "NOT COMPLETE: $PENDING CI checks still pending on PR #$PR_NUMBER ($PENDING_NAMES)" >&2
  exit 1
fi

# --- Check 3: CI passing — no failures ---
FAILING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "fail")] | length' 2>/dev/null || echo "0")
if [ "$FAILING" -gt 0 ]; then
  FAILING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
    | jq -r '[.[] | select(.bucket == "fail") | .name] | join(", ")' 2>/dev/null || echo "unknown")
  echo "NOT COMPLETE: $FAILING failing CI checks on PR #$PR_NUMBER ($FAILING_NAMES)" >&2
  exit 1
fi

# --- Check 4: a new CI run was triggered AFTER the agent spawned ---
if [ -n "$BRANCH" ] && [ "${SPAWNED_AT:-0}" -gt 0 ]; then
  LATEST_RUN_AT=$(gh run list --repo "$REPO" --branch "$BRANCH" \
    --json createdAt --limit 1 2>/dev/null | jq -r '.[0].createdAt // ""')
  if [ -n "$LATEST_RUN_AT" ]; then
    if date --version >/dev/null 2>&1; then
      LATEST_RUN_EPOCH=$(date -d "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
    else
      LATEST_RUN_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
    fi
    if [ "$LATEST_RUN_EPOCH" -le "$SPAWNED_AT" ]; then
      echo "NOT COMPLETE: latest CI run on $BRANCH predates agent spawn — agent may not have pushed yet" >&2
      exit 1
    fi
  fi
fi

OWNER=$(echo "$REPO" | cut -d/ -f1)
REPONAME=$(echo "$REPO" | cut -d/ -f2)

# --- Check 5: no unresolved review threads (checked AFTER CI — bots post after their check) ---
UNRESOLVED=$(gh api graphql -f query="
{
  repository(owner: \"${OWNER}\", name: \"${REPONAME}\") {
    pullRequest(number: ${PR_NUMBER}) {
      reviewThreads(first: 50) { nodes { isResolved } }
    }
  }
}
" --jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)] | length' 2>/dev/null || echo "0")

if [ "$UNRESOLVED" -gt 0 ]; then
  echo "NOT COMPLETE: $UNRESOLVED unresolved review threads on PR #$PR_NUMBER" >&2
  exit 1
fi

# --- Check 6: no CHANGES_REQUESTED (checked AFTER CI — bots post reviews after their check) ---
# A CHANGES_REQUESTED review is stale if the latest commit was pushed AFTER the review was submitted.
# Stale reviews (pre-dating the fixing commits) should not block verification.
#
# Fetch commits and latestReviews in a single call and fail closed — if gh fails,
# treat that as NOT COMPLETE rather than silently passing.
# Use latestReviews (not reviews) so each reviewer's latest state is used — superseded
# CHANGES_REQUESTED entries are automatically excluded when the reviewer later approved.
# Note: we intentionally use committedDate (not PR updatedAt) because updatedAt changes on any
# PR activity (bot comments, label changes) which would create false negatives.
PR_REVIEW_METADATA=$(gh pr view "$PR_NUMBER" --repo "$REPO" \
  --json commits,latestReviews 2>/dev/null) || {
  echo "NOT COMPLETE: unable to fetch PR review metadata for PR #$PR_NUMBER" >&2
  exit 1
}

LATEST_COMMIT_DATE=$(jq -r '.commits[-1].committedDate // ""' <<< "$PR_REVIEW_METADATA")
CHANGES_REQUESTED_REVIEWS=$(jq '[.latestReviews[]? | select(.state == "CHANGES_REQUESTED")]' <<< "$PR_REVIEW_METADATA")

BLOCKING_CHANGES_REQUESTED=0
BLOCKING_REQUESTERS=""

if [ -n "$LATEST_COMMIT_DATE" ] && [ "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length)" -gt 0 ]; then
  if date --version >/dev/null 2>&1; then
    LATEST_COMMIT_EPOCH=$(date -d "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
  else
    LATEST_COMMIT_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
  fi

  while IFS= read -r review; do
    [ -z "$review" ] && continue
    REVIEW_DATE=$(echo "$review" | jq -r '.submittedAt // ""')
    REVIEWER=$(echo "$review" | jq -r '.author.login // "unknown"')
    if [ -z "$REVIEW_DATE" ]; then
      # No submission date — treat as fresh (conservative: blocks verification)
      BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
      BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
    else
      if date --version >/dev/null 2>&1; then
        REVIEW_EPOCH=$(date -d "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
      else
        REVIEW_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
      fi
      if [ "$REVIEW_EPOCH" -gt "$LATEST_COMMIT_EPOCH" ]; then
        # Review was submitted AFTER the latest commit — still fresh, blocks verification
        BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
        BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
      fi
      # Review submitted BEFORE the latest commit — stale, skip
    fi
  done <<< "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -c '.[]')"
else
  # No commit date or no CHANGES_REQUESTED — check raw count as fallback
  BLOCKING_CHANGES_REQUESTED=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length 2>/dev/null || echo "0")
  BLOCKING_REQUESTERS=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -r '[.[].author.login] | join(", ")' 2>/dev/null || echo "unknown")
fi

if [ "$BLOCKING_CHANGES_REQUESTED" -gt 0 ]; then
  echo "NOT COMPLETE: CHANGES_REQUESTED (after latest commit) from ${BLOCKING_REQUESTERS} on PR #$PR_NUMBER" >&2
  exit 1
fi

echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓, CI complete + green, 0 unresolved threads, no CHANGES_REQUESTED"
exit 0
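verify-complete.sh repeats the GNU/BSD `date` branching in three places; the same logic can be factored into a small helper for reuse. A sketch only: the function name `to_epoch` is not part of the script itself:

```shell
# Portable ISO-8601 -> epoch conversion, mirroring the script's two branches.
to_epoch() {
  local ts="$1"
  if date --version >/dev/null 2>&1; then
    # GNU date parses the ISO-8601 timestamp directly
    date -d "$ts" "+%s" 2>/dev/null || echo 0
  else
    # BSD/macOS date needs an explicit input format and UTC pinning
    TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$ts" "+%s" 2>/dev/null || echo 0
  fi
}
to_epoch "2024-01-01T00:00:00Z"  # 1704067200
```

Falling back to `0` on parse failure keeps the epoch comparisons in Checks 4 and 6 safe: an unparseable timestamp never reads as "newer than" a real one.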
@@ -29,30 +29,71 @@ gh pr view {N} --json body --jq '.body'

### 1. Inline review threads — GraphQL (primary source of actionable items)

Use GraphQL to fetch inline threads. It natively exposes `isResolved`, returns threads already grouped with all replies, and paginates via cursor — no manual thread reconstruction needed.

> ⚠️ **WARNING — PAGINATE ALL PAGES BEFORE ADDRESSING ANYTHING**
>
> `reviewThreads(first: 100)` returns at most 100 threads per page. A PR with many review cycles can have 140+ threads across 2+ pages. **If you start addressing threads after fetching only page 1, you will miss all threads on subsequent pages and silently leave them unresolved.**
>
> PR #12636 had 142 total threads: page 1 returned 69 unresolved, page 2 had 42 more (111 total unresolved). An agent that stopped after page 1 addressed only 69 and falsely reported "done".
>
> **The rule: collect ALL thread IDs from ALL pages into a single list, then address them.**

**Step 1 — Fetch the total count first:**

```bash
gh api graphql -f query='
{
  repository(owner: "Significant-Gravitas", name: "AutoGPT") {
    pullRequest(number: {N}) {
      reviewThreads(first: 100) {
        totalCount
        pageInfo { hasNextPage endCursor }
      }
    }
  }
}' | jq '.data.repository.pullRequest.reviewThreads.totalCount'
```

If `totalCount > 100`, you have multiple pages. Fetch them all before doing anything else.

**Step 2 — Collect all unresolved thread IDs across all pages:**

```bash
# Accumulate all unresolved threads — loop until hasNextPage == false
CURSOR=""
ALL_THREADS="[]"
while true; do
  AFTER=${CURSOR:+", after: \"$CURSOR\""}
  PAGE=$(gh api graphql -f query="
  {
    repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
      pullRequest(number: {N}) {
        reviewThreads(first: 100${AFTER}) {
          pageInfo { hasNextPage endCursor }
          nodes {
            id
            isResolved
            path
            line
            comments(last: 1) {
              nodes { databaseId body author { login } }
            }
          }
        }
      }
    }
  }")
  # Append unresolved nodes from this page
  PAGE_THREADS=$(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)]')
  ALL_THREADS=$(echo "$ALL_THREADS $PAGE_THREADS" | jq -s 'add')
  HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
  CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
  [ "$HAS_NEXT" = "false" ] && break
done

echo "Total unresolved threads: $(echo "$ALL_THREADS" | jq 'length')"
echo "$ALL_THREADS" | jq '[.[] | {id, path, line, body: .comments.nodes[0].body[:200]}]'
```

**Step 3 — Address every thread in `ALL_THREADS`, then resolve.**

Only after this loop completes (all pages fetched, count confirmed) should you begin making fixes.

**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
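The page-accumulation step can be exercised offline with stub pages. The thread objects below are fabricated stand-ins for illustration, not real API output:

```shell
# Merge per-page thread arrays exactly as the pagination loop does.
ALL_THREADS="[]"
PAGE1='[{"id":"T1","isResolved":false},{"id":"T2","isResolved":false}]'
PAGE2='[{"id":"T3","isResolved":false}]'
for PAGE_THREADS in "$PAGE1" "$PAGE2"; do
  # jq -s slurps the two JSON arrays into a list; 'add' concatenates them
  ALL_THREADS=$(echo "$ALL_THREADS $PAGE_THREADS" | jq -s 'add')
done
echo "$ALL_THREADS" | jq 'length'  # 3
```

Because the accumulator starts as `[]` and each merge is a pure concatenation, the final list is order-preserving and safe to iterate once all pages are in.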
@@ -84,16 +125,43 @@ Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`gi

## For each unaddressed comment

Address comments **one at a time**: fix → commit → push → inline reply → resolve.

**CRITICAL: The only valid sequence is fix → commit → push → reply → resolve. Never resolve a thread without a real code commit.**

Resolving a thread via `resolveReviewThread` without an actual fix is the most common failure mode — it makes unresolved counts drop without any real change, producing a false "done" signal. If the issue was genuinely a false positive (no code change needed), reply explaining why and then resolve. Otherwise:

1. Read the referenced code, make the fix (or reply explaining why it's not needed)
2. Commit and push the fix
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):

Use a **markdown commit link** so GitHub renders it as a clickable reference. Always get the full SHA with `git rev-parse HEAD` **after** committing — never copy a SHA from a previous commit or hardcode one:

```bash
FULL_SHA=$(git rev-parse HEAD)
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies \
  -f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): <description>"
```

| Comment type | How to reply |
|---|---|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |

### What counts as a valid resolution

Only two situations justify calling `resolveReviewThread`:

1. **Real code fix**: you changed the code, committed + pushed, and replied with the SHA. The commit diff must actually address the concern — not just touch the same file.
2. **Genuine false positive**: the reviewer's concern does not apply to this code, and you can give a specific technical reason (e.g. "Not applicable — `sdk_cwd` is pre-validated by `_make_sdk_cwd()` which applies normpath + prefix assertion before reaching this point").

**Anti-patterns that look resolved but aren't — never do these:**
- `"Accepted, tracked as follow-up"` — a deferral, not a fix. The concern is still open. Do not resolve.
- `"Acknowledged"` or `"Same as above"` — these are acknowledgements, not fixes. Do not resolve.
- `"Fixed in abc1234"` where `abc1234` is a commit that doesn't actually change the flagged line/logic — dishonest. Verify `git show abc1234 -- path/to/file` changes the right thing before posting.
- Resolving without replying — the reviewer never sees what happened.

When in doubt: if a code change is needed, make it. A deferred issue means the thread stays open until the follow-up PR is merged.
||||
|
||||
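One of the anti-patterns above (a "Fixed in ..." reply whose commit never touches the flagged file) can be caught mechanically. A minimal sketch, assuming the thread's `path` is at hand; `verify_fix_touches_file` is an illustrative helper, not part of the repo's tooling:

```shell
# Refuse to resolve a thread when the claimed fix commit does not change
# the flagged file. Pass the fix commit SHA and the thread's `path`.
verify_fix_touches_file() {
  local fix_sha="$1" file_path="$2"
  # --name-only with an empty pretty format lists only the changed paths;
  # grep -qxF matches the exact path as a whole line.
  if git show --name-only --pretty=format: "$fix_sha" | grep -qxF "$file_path"; then
    echo "ok: $fix_sha changes $file_path"
  else
    echo "refusing to resolve: $fix_sha does not touch $file_path" >&2
    return 1
  fi
}
```

Run it before posting the reply; only resolve the thread when it exits 0.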
## Codecov coverage

@@ -141,6 +209,22 @@ Then commit and **push immediately** — never batch commits without pushing. Ea

For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).

## Coverage

Codecov enforces patch coverage on new/changed lines — new code you write must be tested. Before pushing, verify you haven't left new lines uncovered:

```bash
cd autogpt_platform/backend
poetry run pytest --cov=. --cov-report=term-missing {path/to/changed/module}
```

Look for lines marked `miss` — those are uncovered. Add tests for any new code you wrote as part of addressing comments.

**Rules:**
- New code you add should have tests
- Don't remove existing tests when fixing comments
- If a reviewer asks you to delete code, also delete its tests, but verify coverage hasn't dropped on remaining lines

## The loop

```text
@@ -230,3 +314,113 @@ git push
```

5. Restart the polling loop from the top — new commits reset CI status.
## GitHub abuse rate limits

Two distinct rate limits exist — they have different causes and recovery times:

| Error | HTTP code | Cause | Recovery |
|---|---|---|---|
| `{"code":"abuse"}` | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **2–3 minutes**. 60s is often not enough. |
| `{"message":"API rate limit exceeded"}` | 429 | Primary rate limit — too many API calls per hour | Wait until `X-RateLimit-Reset` header timestamp |

**Prevention:** Add `sleep 3` between individual thread reply API calls. When posting >20 replies, increase to `sleep 5`.

**Recovery from secondary rate limit (403):**
1. Stop all API writes immediately
2. Wait **2 minutes minimum** (not 60s — secondary limits are stricter)
3. Resume with `sleep 3` between each call
4. If 403 persists after 2 min, wait another 2 min before retrying

Never batch all replies in a tight loop — always space them out.
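The recovery steps above can be sketched as a retry wrapper. `gh_write_with_backoff` and `BACKOFF_SECONDS` are illustrative names, not existing tooling:

```shell
# Retry a gh write call, waiting out the secondary rate limit between
# attempts. BACKOFF_SECONDS is overridable for testing; real use keeps
# the 120s default from the recovery steps above.
: "${BACKOFF_SECONDS:=120}"

gh_write_with_backoff() {
  local attempt
  for attempt in 1 2 3; do
    if "$@"; then
      return 0
    fi
    echo "write failed (attempt $attempt); waiting ${BACKOFF_SECONDS}s" >&2
    sleep "$BACKOFF_SECONDS"
  done
  return 1
}

# Usage: wrap each write, and still space successive calls 3s apart, e.g.
# gh_write_with_backoff gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="..."
# sleep 3
```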
## Parallel thread resolution

When a PR has more than 10 unresolved threads, addressing one commit per thread is slow. Use this strategy instead:

### Group by file, batch per commit

1. Sort `ALL_THREADS` by `path` — threads in the same file can share a single commit.
2. Fix all threads in one file → `git commit` → `git push` → reply to **all** those threads with the same SHA → resolve them all.
3. Move to the next file group and repeat.

This reduces N commits to (number of files touched), which is usually 3–5 instead of 15–30.
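The grouping in step 1 can be sketched with `jq`, assuming `ALL_THREADS` is a JSON array of `{id, path}` objects as gathered earlier (the sample data here is illustrative):

```shell
# Illustrative thread list; in practice this comes from the GraphQL query.
ALL_THREADS='[{"id":"T1","path":"a.py"},{"id":"T2","path":"b.py"},{"id":"T3","path":"a.py"}]'

# group_by(.path) sorts by path and buckets threads per file, so each
# bucket's threads can be fixed, committed, and replied to together.
echo "$ALL_THREADS" \
  | jq -r 'group_by(.path)[] | "\(.[0].path): \([.[].id] | join(","))"'
# → a.py: T1,T3
#   b.py: T2
```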
### Posting replies concurrently (for large batches)

For truly independent thread groups (different files, no shared logic), you can post replies in parallel using background subshells — but always space out API writes:

```bash
# Post replies to a batch of threads concurrently, 3s apart
(
  sleep 3
  gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID1}/replies \
    -f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): ..."
) &
(
  sleep 6
  gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID2}/replies \
    -f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): ..."
) &
wait  # wait for all background replies before resolving
```

Then resolve sequentially (GraphQL mutations):
```bash
for THREAD_ID in "$THREAD1" "$THREAD2" "$THREAD3"; do
  gh api graphql -f query="mutation { resolveReviewThread(input: {threadId: \"${THREAD_ID}\"}) { thread { isResolved } } }"
  sleep 3
done
```

**Always sleep 3s between individual API writes** — GitHub's secondary rate limit (403) triggers on bursts of >20 writes. Increase to `sleep 5` when posting more than 20 replies in a batch.
## Resolving threads via GraphQL

Use `resolveReviewThread` **only after** the commit is pushed and the reply is posted:

```bash
gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "THREAD_ID"}) { thread { isResolved } } }'
```

**Never call this mutation before committing the fix.** The orchestrator will verify actual unresolved counts via GraphQL after you output `ORCHESTRATOR:DONE` — false resolutions will be caught and you will be re-briefed.
### Verify actual count before outputting ORCHESTRATOR:DONE

Before claiming "0 unresolved threads", always query GitHub directly — don't rely on your own bookkeeping. Paginate all pages — a single `first: 100` query misses threads beyond page 1:

```bash
# Step 1: get total thread count
gh api graphql -f query='
{
  repository(owner: "Significant-Gravitas", name: "AutoGPT") {
    pullRequest(number: {N}) {
      reviewThreads { totalCount }
    }
  }
}' | jq '.data.repository.pullRequest.reviewThreads.totalCount'

# Step 2: paginate all pages, count truly unresolved
CURSOR=""; UNRESOLVED=0
while true; do
  AFTER=${CURSOR:+", after: \"$CURSOR\""}
  PAGE=$(gh api graphql -f query="
  {
    repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
      pullRequest(number: {N}) {
        reviewThreads(first: 100${AFTER}) {
          pageInfo { hasNextPage endCursor }
          nodes { isResolved }
        }
      }
    }
  }")
  UNRESOLVED=$(( UNRESOLVED + $(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved==false)] | length') ))
  HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
  CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
  [ "$HAS_NEXT" = "false" ] && break
done
echo "Unresolved threads: $UNRESOLVED"
```

Only output `ORCHESTRATOR:DONE` after this loop reports 0.
@@ -310,6 +310,28 @@ TOKEN=$(curl -s -X POST 'http://localhost:8000/auth/v1/token?grant_type=password
curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
```

### 3i. Disable onboarding for test user

The frontend redirects to `/onboarding` when the `VISIT_COPILOT` step is not in `completedSteps`.
Mark it complete via the backend API so every browser test lands on the real feature UI:

```bash
ONBOARDING_RESULT=$(curl -s --max-time 30 -X POST \
  "http://localhost:8006/api/onboarding/step?step=VISIT_COPILOT" \
  -H "Authorization: Bearer $TOKEN")
echo "Onboarding bypass: $ONBOARDING_RESULT"

# Verify it took effect
ONBOARDING_STATUS=$(curl -s --max-time 30 \
  "http://localhost:8006/api/onboarding/completed" \
  -H "Authorization: Bearer $TOKEN" | jq -r '.is_completed')
echo "Onboarding completed: $ONBOARDING_STATUS"
if [ "$ONBOARDING_STATUS" != "true" ]; then
  echo "ERROR: onboarding bypass failed — browser tests will hit /onboarding instead of the target feature. Investigate before proceeding."
  exit 1
fi
```
## Step 4: Run tests

### Service ports reference
@@ -547,6 +569,8 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations

**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**

**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as an inline image (`![description](url)`) in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.

```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"
@@ -584,15 +608,27 @@ TREE_JSON+=']'

# Step 2: Create tree, commit, and branch ref
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')

# Resolve parent commit so screenshots are chained, not orphan root commits
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || echo "")
if [ -n "$PARENT_SHA" ]; then
  COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
    -f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
    -f tree="$TREE_SHA" \
    -f "parents[]=$PARENT_SHA" \
    --jq '.sha')
else
  COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
    -f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
    -f tree="$TREE_SHA" \
    --jq '.sha')
fi

gh api "repos/${REPO}/git/refs" \
  -f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
  -f sha="$COMMIT_SHA" 2>/dev/null \
  || gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
    -X PATCH -f sha="$COMMIT_SHA" -F force=true
```
Then post the comment with **inline images AND explanations for each screenshot**:
@@ -658,6 +694,15 @@ INNEREOF

gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
rm -f "$COMMENT_FILE"

# Verify the posted comment contains inline images — exit 1 if none found
# Use separate --paginate + jq pipe: --jq applies per-page, not to the full list
LAST_COMMENT=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" --paginate 2>/dev/null | jq -r '.[-1].body // ""')
if ! echo "$LAST_COMMENT" | grep -q '!\['; then
  echo "ERROR: Posted comment contains no inline images (![). Bare directory links are not acceptable." >&2
  exit 1
fi
echo "✓ Inline images verified in posted comment"
```

**The PR comment MUST include:**
@@ -667,6 +712,103 @@ rm -f "$COMMENT_FILE"

This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
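Step 1 (blob creation) is elided by the hunk above. A minimal sketch of what it presumably does with the Git Data API; `build_tree_json` and the `GH_API` indirection are illustrative names, not the repo's exact script:

```shell
# Assumed sketch of Step 1: create one blob per screenshot and accumulate
# tree entries for Step 2. GH_API names the API command so it can be
# stubbed in tests; in real use it is `gh api`.
GH_API="${GH_API:-gh api}"

build_tree_json() {
  local repo="$1"; shift
  local entries="[" blob_sha shot
  for shot in "$@"; do
    # Upload the file content as a base64-encoded blob; the API returns its SHA.
    blob_sha=$(base64 < "$shot" | tr -d '\n' \
      | jq -R '{content: ., encoding: "base64"}' \
      | $GH_API "repos/${repo}/git/blobs" --input - --jq '.sha')
    entries+="{\"path\":\"$(basename "$shot")\",\"mode\":\"100644\",\"type\":\"blob\",\"sha\":\"${blob_sha}\"},"
  done
  printf '%s' "${entries%,}]"   # drop trailing comma, close the JSON array
}
```

The resulting array is what the `TREE_JSON+=']'` context line above finalizes before the tree is created.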
## Step 8: Evaluate and post a formal PR review

After the test comment is posted, evaluate whether the run was thorough enough to make a merge decision, then post a formal GitHub review (approve or request changes). **This step is mandatory — every test run MUST end with a formal review decision.**

### Evaluation criteria

Re-read the PR description:
```bash
gh pr view "$PR_NUMBER" --json body --jq '.body' --repo "$REPO"
```

Score the run against each criterion:

| Criterion | Pass condition |
|-----------|---------------|
| **Coverage** | Every feature/change described in the PR has at least one test scenario |
| **All scenarios pass** | No FAIL rows in the results table |
| **Negative tests** | At least one failure-path test per feature (invalid input, unauthorized, edge case) |
| **Before/after evidence** | Every state-changing API call has before/after values logged |
| **Screenshots are meaningful** | Screenshots show the actual state change, not just a loading spinner or blank page |
| **No regressions** | Existing core flows (login, agent create/run) still work |

### Decision logic

```
ALL criteria pass                              → APPROVE
Any scenario FAIL or missing PR feature        → REQUEST_CHANGES (list gaps)
Evidence weak (no before/after, vague shots)   → REQUEST_CHANGES (list what's missing)
```

### Post the review

```bash
REVIEW_FILE=$(mktemp)

# Count results
PASS_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "PASS" || true)
FAIL_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "FAIL" || true)
TOTAL=$(( PASS_COUNT + FAIL_COUNT ))

# List any coverage gaps found during evaluation (populate this array as you assess)
# e.g. COVERAGE_GAPS=("PR claims to add X but no test covers it")
COVERAGE_GAPS=()
```

**If APPROVING** — all criteria met, zero failures, full coverage:

```bash
cat > "$REVIEW_FILE" <<REVIEWEOF
## E2E Test Evaluation — APPROVED

**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed.

**Coverage:** All features described in the PR were exercised.

**Evidence:** Before/after API values logged for all state-changing operations; screenshots show meaningful state transitions.

**Negative tests:** Failure paths tested for each feature.

No regressions observed on core flows.
REVIEWEOF

gh pr review "$PR_NUMBER" --repo "$REPO" --approve --body "$(cat "$REVIEW_FILE")"
echo "✅ PR approved"
```

**If REQUESTING CHANGES** — any failure, coverage gap, or missing evidence:

```bash
FAIL_LIST=$(echo "$TEST_RESULTS_TABLE" | grep "FAIL" | awk -F'|' '{print "- Scenario" $2 "failed"}' || true)

cat > "$REVIEW_FILE" <<REVIEWEOF
## E2E Test Evaluation — Changes Requested

**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed, ${FAIL_COUNT} failed.

### Required before merge

${FAIL_LIST}
$(for gap in "${COVERAGE_GAPS[@]}"; do echo "- $gap"; done)

Please fix the above and re-run the E2E tests.
REVIEWEOF

gh pr review "$PR_NUMBER" --repo "$REPO" --request-changes --body "$(cat "$REVIEW_FILE")"
echo "❌ Changes requested"
```

```bash
rm -f "$REVIEW_FILE"
```

**Rules:**
- In `--fix` mode, fix all failures before posting the review — the review reflects the final state after fixes
- Never approve if any scenario failed, even if it seems like a flake — rerun that scenario first
- Never request changes for issues already fixed in this run

## Fix mode (--fix flag)

When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.
@@ -0,0 +1,84 @@
import logging
from datetime import datetime

from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, Query, Security
from pydantic import BaseModel

from backend.data.platform_cost import (
    CostLogRow,
    PlatformCostDashboard,
    get_platform_cost_dashboard,
    get_platform_cost_logs,
)
from backend.util.models import Pagination

logger = logging.getLogger(__name__)


router = APIRouter(
    prefix="/platform-costs",
    tags=["platform-cost", "admin"],
    dependencies=[Security(requires_admin_user)],
)


class PlatformCostLogsResponse(BaseModel):
    logs: list[CostLogRow]
    pagination: Pagination


@router.get(
    "/dashboard",
    response_model=PlatformCostDashboard,
    summary="Get Platform Cost Dashboard",
)
async def get_cost_dashboard(
    admin_user_id: str = Security(get_user_id),
    start: datetime | None = Query(None),
    end: datetime | None = Query(None),
    provider: str | None = Query(None),
    user_id: str | None = Query(None),
):
    logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
    return await get_platform_cost_dashboard(
        start=start,
        end=end,
        provider=provider,
        user_id=user_id,
    )


@router.get(
    "/logs",
    response_model=PlatformCostLogsResponse,
    summary="Get Platform Cost Logs",
)
async def get_cost_logs(
    admin_user_id: str = Security(get_user_id),
    start: datetime | None = Query(None),
    end: datetime | None = Query(None),
    provider: str | None = Query(None),
    user_id: str | None = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
):
    logger.info("Admin %s fetching platform cost logs", admin_user_id)
    logs, total = await get_platform_cost_logs(
        start=start,
        end=end,
        provider=provider,
        user_id=user_id,
        page=page,
        page_size=page_size,
    )
    total_pages = (total + page_size - 1) // page_size
    return PlatformCostLogsResponse(
        logs=logs,
        pagination=Pagination(
            total_items=total,
            total_pages=total_pages,
            current_page=page,
            page_size=page_size,
        ),
    )
@@ -0,0 +1,192 @@
from unittest.mock import AsyncMock

import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload

from backend.data.platform_cost import PlatformCostDashboard

from .platform_cost_routes import router as platform_cost_router

app = fastapi.FastAPI()
app.include_router(platform_cost_router)

client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
    """Setup admin auth overrides for all tests in this module"""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


def test_get_dashboard_success(
    mocker: pytest_mock.MockerFixture,
) -> None:
    real_dashboard = PlatformCostDashboard(
        by_provider=[],
        by_user=[],
        total_cost_microdollars=0,
        total_requests=0,
        total_users=0,
    )
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
        AsyncMock(return_value=real_dashboard),
    )

    response = client.get("/platform-costs/dashboard")
    assert response.status_code == 200
    data = response.json()
    assert "by_provider" in data
    assert "by_user" in data
    assert data["total_cost_microdollars"] == 0


def test_get_logs_success(
    mocker: pytest_mock.MockerFixture,
) -> None:
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
        AsyncMock(return_value=([], 0)),
    )

    response = client.get("/platform-costs/logs")
    assert response.status_code == 200
    data = response.json()
    assert data["logs"] == []
    assert data["pagination"]["total_items"] == 0


def test_get_dashboard_with_filters(
    mocker: pytest_mock.MockerFixture,
) -> None:
    real_dashboard = PlatformCostDashboard(
        by_provider=[],
        by_user=[],
        total_cost_microdollars=0,
        total_requests=0,
        total_users=0,
    )
    mock_dashboard = AsyncMock(return_value=real_dashboard)
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
        mock_dashboard,
    )

    response = client.get(
        "/platform-costs/dashboard",
        params={
            "start": "2026-01-01T00:00:00",
            "end": "2026-04-01T00:00:00",
            "provider": "openai",
            "user_id": "test-user-123",
        },
    )
    assert response.status_code == 200
    mock_dashboard.assert_called_once()
    call_kwargs = mock_dashboard.call_args.kwargs
    assert call_kwargs["provider"] == "openai"
    assert call_kwargs["user_id"] == "test-user-123"
    assert call_kwargs["start"] is not None
    assert call_kwargs["end"] is not None


def test_get_logs_with_pagination(
    mocker: pytest_mock.MockerFixture,
) -> None:
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
        AsyncMock(return_value=([], 0)),
    )

    response = client.get(
        "/platform-costs/logs",
        params={"page": 2, "page_size": 25, "provider": "anthropic"},
    )
    assert response.status_code == 200
    data = response.json()
    assert data["pagination"]["current_page"] == 2
    assert data["pagination"]["page_size"] == 25


def test_get_dashboard_requires_admin() -> None:
    import fastapi
    from fastapi import HTTPException

    def reject_jwt(request: fastapi.Request):
        raise HTTPException(status_code=401, detail="Not authenticated")

    app.dependency_overrides[get_jwt_payload] = reject_jwt
    try:
        response = client.get("/platform-costs/dashboard")
        assert response.status_code == 401
        response = client.get("/platform-costs/logs")
        assert response.status_code == 401
    finally:
        app.dependency_overrides.clear()


def test_get_dashboard_rejects_non_admin(mock_jwt_user, mock_jwt_admin) -> None:
    """Non-admin JWT must be rejected with 403 by requires_admin_user."""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    try:
        response = client.get("/platform-costs/dashboard")
        assert response.status_code == 403
        response = client.get("/platform-costs/logs")
        assert response.status_code == 403
    finally:
        app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]


def test_get_logs_invalid_page_size_too_large() -> None:
    """page_size > 200 must be rejected with 422."""
    response = client.get("/platform-costs/logs", params={"page_size": 201})
    assert response.status_code == 422


def test_get_logs_invalid_page_size_zero() -> None:
    """page_size = 0 (below ge=1) must be rejected with 422."""
    response = client.get("/platform-costs/logs", params={"page_size": 0})
    assert response.status_code == 422


def test_get_logs_invalid_page_negative() -> None:
    """page < 1 must be rejected with 422."""
    response = client.get("/platform-costs/logs", params={"page": 0})
    assert response.status_code == 422


def test_get_dashboard_invalid_date_format() -> None:
    """Malformed start date must be rejected with 422."""
    response = client.get("/platform-costs/dashboard", params={"start": "not-a-date"})
    assert response.status_code == 422


def test_get_dashboard_repeated_requests(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Repeated requests to the dashboard route both return 200."""
    real_dashboard = PlatformCostDashboard(
        by_provider=[],
        by_user=[],
        total_cost_microdollars=42,
        total_requests=1,
        total_users=1,
    )
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
        AsyncMock(return_value=real_dashboard),
    )

    r1 = client.get("/platform-costs/dashboard")
    r2 = client.get("/platform-costs/dashboard")

    assert r1.status_code == 200
    assert r2.status_code == 200
    assert r1.json()["total_cost_microdollars"] == 42
    assert r2.json()["total_cost_microdollars"] == 42
@@ -16,6 +16,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
    ChatMessage,
@@ -155,6 +156,8 @@ class SessionDetailResponse(BaseModel):
    user_id: str | None
    messages: list[dict]
    active_stream: ActiveStreamInfo | None = None  # Present if stream is still active
    has_more_messages: bool = False
    oldest_sequence: int | None = None
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    metadata: ChatSessionMetadata = ChatSessionMetadata()
@@ -394,60 +397,78 @@ async def update_session_title_route(
async def get_session(
    session_id: str,
    user_id: Annotated[str, Security(auth.get_user_id)],
    limit: int = Query(default=50, ge=1, le=200),
    before_sequence: int | None = Query(default=None, ge=0),
) -> SessionDetailResponse:
    """
    Retrieve the details of a specific chat session.

    Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
    If there's an active stream for this session, returns active_stream info for reconnection.
    Supports cursor-based pagination via ``limit`` and ``before_sequence``.
    When no pagination params are provided, returns the most recent messages.

    Args:
        session_id: The unique identifier for the desired chat session.
        user_id: The authenticated user's ID.
        limit: Maximum number of messages to return (1-200, default 50).
        before_sequence: Return messages with sequence < this value (cursor).

    Returns:
        SessionDetailResponse: Details for the requested session, including
        active_stream info and pagination metadata.
    """
    page = await get_chat_messages_paginated(
        session_id, limit, before_sequence, user_id=user_id
    )
    if page is None:
        raise NotFoundError(f"Session {session_id} not found.")
    messages = [message.model_dump() for message in page.messages]

    # Check if there's an active stream for this session
    # Only check active stream on initial load (not on "load more" requests)
    active_stream_info = None
    if before_sequence is None:
        active_session, last_message_id = await stream_registry.get_active_session(
            session_id, user_id
        )
        logger.info(
            f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
            f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
        )
        if active_session:
            active_stream_info = ActiveStreamInfo(
                turn_id=active_session.turn_id,
                last_message_id=last_message_id,
            )

    # Skip session metadata on "load more" — frontend only needs messages
    if before_sequence is not None:
        return SessionDetailResponse(
            id=page.session.session_id,
            created_at=page.session.started_at.isoformat(),
            updated_at=page.session.updated_at.isoformat(),
            user_id=page.session.user_id or None,
            messages=messages,
            active_stream=None,
            has_more_messages=page.has_more,
            oldest_sequence=page.oldest_sequence,
            total_prompt_tokens=0,
            total_completion_tokens=0,
        )

    # Sum token usage from session
    total_prompt = sum(u.prompt_tokens for u in page.session.usage)
    total_completion = sum(u.completion_tokens for u in page.session.usage)

    return SessionDetailResponse(
        id=page.session.session_id,
        created_at=page.session.started_at.isoformat(),
        updated_at=page.session.updated_at.isoformat(),
        user_id=page.session.user_id or None,
        messages=messages,
        active_stream=active_stream_info,
        has_more_messages=page.has_more,
        oldest_sequence=page.oldest_sequence,
        total_prompt_tokens=total_prompt,
        total_completion_tokens=total_completion,
        metadata=page.session.metadata,
    )
@@ -12,7 +12,7 @@ import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
@@ -131,9 +131,26 @@ class StorageUsageResponse(BaseModel):
     file_count: int


+class WorkspaceFileItem(BaseModel):
+    id: str
+    name: str
+    path: str
+    mime_type: str
+    size_bytes: int
+    metadata: dict = Field(default_factory=dict)
+    created_at: str
+
+
+class ListFilesResponse(BaseModel):
+    files: list[WorkspaceFileItem]
+    offset: int = 0
+    has_more: bool = False
+
+
 @router.get(
     "/files/{file_id}/download",
     summary="Download file by ID",
     operation_id="getWorkspaceDownloadFileById",
 )
 async def download_file(
     user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -158,6 +175,7 @@ async def download_file(
 @router.delete(
     "/files/{file_id}",
     summary="Delete a workspace file",
     operation_id="deleteWorkspaceFile",
 )
 async def delete_workspace_file(
     user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -183,6 +201,7 @@ async def delete_workspace_file(
 @router.post(
     "/files/upload",
     summary="Upload file to workspace",
     operation_id="uploadWorkspaceFile",
 )
 async def upload_file(
     user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -196,6 +215,9 @@ async def upload_file(
     Files are stored in session-scoped paths when session_id is provided,
     so the agent's session-scoped tools can discover them automatically.
     """
+    # Empty-string session_id drops session scoping; normalize to None.
+    session_id = session_id or None

     config = Config()

     # Sanitize filename — strip any directory components
@@ -250,16 +272,27 @@ async def upload_file(
     manager = WorkspaceManager(user_id, workspace.id, session_id)
     try:
         workspace_file = await manager.write_file(
-            content, filename, overwrite=overwrite
+            content, filename, overwrite=overwrite, metadata={"origin": "user-upload"}
         )
     except ValueError as e:
-        raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
+        # write_file raises ValueError for both path-conflict and size-limit
+        # cases; map each to its correct HTTP status.
+        message = str(e)
+        if message.startswith("File too large"):
+            raise fastapi.HTTPException(status_code=413, detail=message) from e
+        raise fastapi.HTTPException(status_code=409, detail=message) from e

     # Post-write storage check — eliminates TOCTOU race on the quota.
     # If a concurrent upload pushed us over the limit, undo this write.
     new_total = await get_workspace_total_size(workspace.id)
     if storage_limit_bytes and new_total > storage_limit_bytes:
-        await soft_delete_workspace_file(workspace_file.id, workspace.id)
+        try:
+            await soft_delete_workspace_file(workspace_file.id, workspace.id)
+        except Exception as e:
+            logger.warning(
+                f"Failed to soft-delete over-quota file {workspace_file.id} "
+                f"in workspace {workspace.id}: {e}"
+            )
         raise fastapi.HTTPException(
             status_code=413,
             detail={
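The post-write re-check above is a general remedy for TOCTOU-prone quota enforcement: write first, re-read the authoritative total, and undo your own write on overflow. A minimal standalone sketch of the pattern, with hypothetical names that are not part of this codebase:

```python
class QuotaExceeded(Exception):
    """Raised when a write pushes the store past its byte limit."""


def write_with_quota(store: dict, key: str, data: bytes, limit: int) -> None:
    """Write first, then re-check the total; roll back the write if over quota."""
    store[key] = data
    total = sum(len(v) for v in store.values())
    if total > limit:
        # A concurrent writer may have pushed us over the limit between any
        # pre-check and our write; undo only our own contribution.
        del store[key]
        raise QuotaExceeded(f"{total} bytes exceeds {limit}-byte limit")
```

Unlike a pre-check alone, this closes the race window because the total is measured after the write is already visible to other writers.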
@@ -281,6 +314,7 @@ async def upload_file(
 @router.get(
     "/storage/usage",
     summary="Get workspace storage usage",
     operation_id="getWorkspaceStorageUsage",
 )
 async def get_storage_usage(
     user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -301,3 +335,57 @@ async def get_storage_usage(
         used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
         file_count=file_count,
     )
+
+
+@router.get(
+    "/files",
+    summary="List workspace files",
+    operation_id="listWorkspaceFiles",
+)
+async def list_workspace_files(
+    user_id: Annotated[str, fastapi.Security(get_user_id)],
+    session_id: str | None = Query(default=None),
+    limit: int = Query(default=200, ge=1, le=1000),
+    offset: int = Query(default=0, ge=0),
+) -> ListFilesResponse:
+    """
+    List files in the user's workspace.
+
+    When session_id is provided, only files for that session are returned.
+    Otherwise, all files across sessions are listed. Results are paginated
+    via `limit`/`offset`; `has_more` indicates whether additional pages exist.
+    """
+    workspace = await get_or_create_workspace(user_id)
+
+    # Treat empty-string session_id the same as omitted — an empty value
+    # would otherwise silently list files across every session instead of
+    # scoping to one.
+    session_id = session_id or None
+
+    manager = WorkspaceManager(user_id, workspace.id, session_id)
+    include_all = session_id is None
+    # Fetch one extra to compute has_more without a separate count query.
+    files = await manager.list_files(
+        limit=limit + 1,
+        offset=offset,
+        include_all_sessions=include_all,
+    )
+    has_more = len(files) > limit
+    page = files[:limit]
+
+    return ListFilesResponse(
+        files=[
+            WorkspaceFileItem(
+                id=f.id,
+                name=f.name,
+                path=f.path,
+                mime_type=f.mime_type,
+                size_bytes=f.size_bytes,
+                metadata=f.metadata or {},
+                created_at=f.created_at.isoformat(),
+            )
+            for f in page
+        ],
+        offset=offset,
+        has_more=has_more,
+    )
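The `limit + 1` fetch in `list_workspace_files` is a common way to derive `has_more` without issuing a second COUNT query. The idea, sketched independently of this codebase (the list slice stands in for the database query):

```python
def paginate(items: list, limit: int, offset: int) -> tuple[list, bool]:
    """Fetch limit+1 rows; the extra row's mere presence signals another page."""
    window = items[offset : offset + limit + 1]  # stands in for the DB query
    has_more = len(window) > limit
    return window[:limit], has_more
```

The extra row is trimmed before returning, so the caller always sees at most `limit` items plus a boolean, at the cost of fetching one surplus row per page.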
@@ -1,48 +1,28 @@
 """Tests for workspace file upload and download routes."""

 import io
 from datetime import datetime, timezone
+from unittest.mock import AsyncMock, MagicMock, patch

 import fastapi
 import fastapi.testclient
 import pytest
-import pytest_mock

-from backend.api.features.workspace import routes as workspace_routes
-from backend.data.workspace import WorkspaceFile
+from backend.api.features.workspace.routes import router
+from backend.data.workspace import Workspace, WorkspaceFile

 app = fastapi.FastAPI()
-app.include_router(workspace_routes.router)
+app.include_router(router)


 @app.exception_handler(ValueError)
 async def _value_error_handler(
     request: fastapi.Request, exc: ValueError
 ) -> fastapi.responses.JSONResponse:
-    """Mirror the production ValueError → 400 mapping from rest_api.py."""
+    """Mirror the production ValueError → 400 mapping from the REST app."""
     return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})


 client = fastapi.testclient.TestClient(app)

 TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"

-MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
-
-_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
-
-MOCK_FILE = WorkspaceFile(
-    id="file-aaa-bbb",
-    workspace_id="ws-1",
-    created_at=_NOW,
-    updated_at=_NOW,
-    name="hello.txt",
-    path="/session/hello.txt",
-    mime_type="text/plain",
-    size_bytes=13,
-    storage_path="local://hello.txt",
-)


 @pytest.fixture(autouse=True)
 def setup_app_auth(mock_jwt_user):
@@ -53,25 +33,201 @@ def setup_app_auth(mock_jwt_user):
     app.dependency_overrides.clear()


+def _make_workspace(user_id: str = "test-user-id") -> Workspace:
+    return Workspace(
+        id="ws-001",
+        user_id=user_id,
+        created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
+        updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
+    )
+
+
+def _make_file(**overrides) -> WorkspaceFile:
+    defaults = {
+        "id": "file-001",
+        "workspace_id": "ws-001",
+        "created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
+        "updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
+        "name": "test.txt",
+        "path": "/test.txt",
+        "storage_path": "local://test.txt",
+        "mime_type": "text/plain",
+        "size_bytes": 100,
+        "checksum": None,
+        "is_deleted": False,
+        "deleted_at": None,
+        "metadata": {},
+    }
+    defaults.update(overrides)
+    return WorkspaceFile(**defaults)
+
+
+def _make_file_mock(**overrides) -> MagicMock:
+    """Create a mock WorkspaceFile to simulate DB records with null fields."""
+    defaults = {
+        "id": "file-001",
+        "name": "test.txt",
+        "path": "/test.txt",
+        "mime_type": "text/plain",
+        "size_bytes": 100,
+        "metadata": {},
+        "created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
+    }
+    defaults.update(overrides)
+    mock = MagicMock(spec=WorkspaceFile)
+    for k, v in defaults.items():
+        setattr(mock, k, v)
+    return mock
+
+
+# -- list_workspace_files tests --
+
+
+@patch("backend.api.features.workspace.routes.get_or_create_workspace")
+@patch("backend.api.features.workspace.routes.WorkspaceManager")
+def test_list_files_returns_all_when_no_session(mock_manager_cls, mock_get_workspace):
+    mock_get_workspace.return_value = _make_workspace()
+    files = [
+        _make_file(id="f1", name="a.txt", metadata={"origin": "user-upload"}),
+        _make_file(id="f2", name="b.csv", metadata={"origin": "agent-created"}),
+    ]
+    mock_instance = AsyncMock()
+    mock_instance.list_files.return_value = files
+    mock_manager_cls.return_value = mock_instance
+
+    response = client.get("/files")
+    assert response.status_code == 200
+
+    data = response.json()
+    assert len(data["files"]) == 2
+    assert data["has_more"] is False
+    assert data["offset"] == 0
+    assert data["files"][0]["id"] == "f1"
+    assert data["files"][0]["metadata"] == {"origin": "user-upload"}
+    assert data["files"][1]["id"] == "f2"
+    mock_instance.list_files.assert_called_once_with(
+        limit=201, offset=0, include_all_sessions=True
+    )
+
+
+@patch("backend.api.features.workspace.routes.get_or_create_workspace")
+@patch("backend.api.features.workspace.routes.WorkspaceManager")
+def test_list_files_scopes_to_session_when_provided(
+    mock_manager_cls, mock_get_workspace, test_user_id
+):
+    mock_get_workspace.return_value = _make_workspace(user_id=test_user_id)
+    mock_instance = AsyncMock()
+    mock_instance.list_files.return_value = []
+    mock_manager_cls.return_value = mock_instance
+
+    response = client.get("/files?session_id=sess-123")
+    assert response.status_code == 200
+
+    data = response.json()
+    assert data["files"] == []
+    assert data["has_more"] is False
+    mock_manager_cls.assert_called_once_with(test_user_id, "ws-001", "sess-123")
+    mock_instance.list_files.assert_called_once_with(
+        limit=201, offset=0, include_all_sessions=False
+    )
+
+
+@patch("backend.api.features.workspace.routes.get_or_create_workspace")
+@patch("backend.api.features.workspace.routes.WorkspaceManager")
+def test_list_files_null_metadata_coerced_to_empty_dict(
+    mock_manager_cls, mock_get_workspace
+):
+    """Route uses `f.metadata or {}` for pre-existing files with null metadata."""
+    mock_get_workspace.return_value = _make_workspace()
+    mock_instance = AsyncMock()
+    mock_instance.list_files.return_value = [_make_file_mock(metadata=None)]
+    mock_manager_cls.return_value = mock_instance
+
+    response = client.get("/files")
+    assert response.status_code == 200
+    assert response.json()["files"][0]["metadata"] == {}
+
+
+# -- upload_file metadata tests --
+
+
+@patch("backend.api.features.workspace.routes.get_or_create_workspace")
+@patch("backend.api.features.workspace.routes.get_workspace_total_size")
+@patch("backend.api.features.workspace.routes.scan_content_safe")
+@patch("backend.api.features.workspace.routes.WorkspaceManager")
+def test_upload_passes_user_upload_origin_metadata(
+    mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
+):
+    mock_get_workspace.return_value = _make_workspace()
+    mock_total_size.return_value = 100
+    written = _make_file(id="new-file", name="doc.pdf")
+    mock_instance = AsyncMock()
+    mock_instance.write_file.return_value = written
+    mock_manager_cls.return_value = mock_instance
+
+    response = client.post(
+        "/files/upload",
+        files={"file": ("doc.pdf", b"fake-pdf-content", "application/pdf")},
+    )
+    assert response.status_code == 200
+
+    mock_instance.write_file.assert_called_once()
+    call_kwargs = mock_instance.write_file.call_args
+    assert call_kwargs.kwargs.get("metadata") == {"origin": "user-upload"}
+
+
+@patch("backend.api.features.workspace.routes.get_or_create_workspace")
+@patch("backend.api.features.workspace.routes.get_workspace_total_size")
+@patch("backend.api.features.workspace.routes.scan_content_safe")
+@patch("backend.api.features.workspace.routes.WorkspaceManager")
+def test_upload_returns_409_on_file_conflict(
+    mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
+):
+    mock_get_workspace.return_value = _make_workspace()
+    mock_total_size.return_value = 100
+    mock_instance = AsyncMock()
+    mock_instance.write_file.side_effect = ValueError("File already exists at path")
+    mock_manager_cls.return_value = mock_instance
+
+    response = client.post(
+        "/files/upload",
+        files={"file": ("dup.txt", b"content", "text/plain")},
+    )
+    assert response.status_code == 409
+    assert "already exists" in response.json()["detail"]
+
+
+# -- Restored upload/download/delete security + invariant tests --


 def _upload(
     filename: str = "hello.txt",
     content: bytes = b"Hello, world!",
     content_type: str = "text/plain",
 ):
     """Helper to POST a file upload."""
     return client.post(
         "/files/upload?session_id=sess-1",
         files={"file": (filename, io.BytesIO(content), content_type)},
     )


 # ---- Happy path ----
+_MOCK_FILE = WorkspaceFile(
+    id="file-aaa-bbb",
+    workspace_id="ws-001",
+    created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
+    updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
+    name="hello.txt",
+    path="/sessions/sess-1/hello.txt",
+    mime_type="text/plain",
+    size_bytes=13,
+    storage_path="local://hello.txt",
+)
+
+
-def test_upload_happy_path(mocker: pytest_mock.MockFixture):
+def test_upload_happy_path(mocker):
     mocker.patch(
         "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
+        return_value=_make_workspace(),
     )
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace_total_size",
@@ -82,7 +238,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
         return_value=None,
     )
     mock_manager = mocker.MagicMock()
-    mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
+    mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
     mocker.patch(
         "backend.api.features.workspace.routes.WorkspaceManager",
         return_value=mock_manager,
@@ -96,10 +252,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
     assert data["size_bytes"] == 13


-# ---- Per-file size limit ----
-
-
-def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
+def test_upload_exceeds_max_file_size(mocker):
     """Files larger than max_file_size_mb should be rejected with 413."""
     cfg = mocker.patch("backend.api.features.workspace.routes.Config")
     cfg.return_value.max_file_size_mb = 0  # 0 MB → any content is too big
@@ -109,15 +262,11 @@ def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
     assert response.status_code == 413


-# ---- Storage quota exceeded ----
-
-
-def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
+def test_upload_storage_quota_exceeded(mocker):
     mocker.patch(
         "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
+        return_value=_make_workspace(),
     )
     # Current usage already at limit
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace_total_size",
         return_value=500 * 1024 * 1024,
@@ -128,27 +277,22 @@ def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
     assert "Storage limit exceeded" in response.text


-# ---- Post-write quota race (B2) ----
-
-
-def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
-    """If a concurrent upload tips the total over the limit after write,
-    the file should be soft-deleted and 413 returned."""
+def test_upload_post_write_quota_race(mocker):
+    """Concurrent upload tipping over limit after write should soft-delete + 413."""
     mocker.patch(
         "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
+        return_value=_make_workspace(),
     )
     # Pre-write check passes (under limit), but post-write check fails
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace_total_size",
-        side_effect=[0, 600 * 1024 * 1024],  # first call OK, second over limit
+        side_effect=[0, 600 * 1024 * 1024],
     )
     mocker.patch(
         "backend.api.features.workspace.routes.scan_content_safe",
         return_value=None,
     )
     mock_manager = mocker.MagicMock()
-    mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
+    mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
     mocker.patch(
         "backend.api.features.workspace.routes.WorkspaceManager",
         return_value=mock_manager,
@@ -160,17 +304,14 @@ def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):

     response = _upload()
     assert response.status_code == 413
-    mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
+    mock_delete.assert_called_once_with("file-aaa-bbb", "ws-001")


-# ---- Any extension accepted (no allowlist) ----
-
-
-def test_upload_any_extension(mocker: pytest_mock.MockFixture):
+def test_upload_any_extension(mocker):
     """Any file extension should be accepted — ClamAV is the security layer."""
     mocker.patch(
         "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
+        return_value=_make_workspace(),
     )
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace_total_size",
@@ -181,7 +322,7 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
         return_value=None,
     )
     mock_manager = mocker.MagicMock()
-    mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
+    mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
     mocker.patch(
         "backend.api.features.workspace.routes.WorkspaceManager",
         return_value=mock_manager,
@@ -191,16 +332,13 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
     assert response.status_code == 200


-# ---- Virus scan rejection ----
-
-
-def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
+def test_upload_blocked_by_virus_scan(mocker):
    """Files flagged by ClamAV should be rejected and never written to storage."""
    from backend.api.features.store.exceptions import VirusDetectedError

     mocker.patch(
         "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
+        return_value=_make_workspace(),
     )
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace_total_size",
@@ -211,7 +349,7 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
         side_effect=VirusDetectedError("Eicar-Test-Signature"),
     )
     mock_manager = mocker.MagicMock()
-    mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
+    mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
     mocker.patch(
         "backend.api.features.workspace.routes.WorkspaceManager",
         return_value=mock_manager,
@@ -219,18 +357,14 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):

     response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
     assert response.status_code == 400
     assert "Virus detected" in response.text
     mock_manager.write_file.assert_not_called()


-# ---- No file extension ----
-
-
-def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
+def test_upload_file_without_extension(mocker):
     """Files without an extension should be accepted and stored as-is."""
     mocker.patch(
         "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
+        return_value=_make_workspace(),
     )
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace_total_size",
@@ -241,7 +375,7 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
         return_value=None,
     )
     mock_manager = mocker.MagicMock()
-    mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
+    mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
     mocker.patch(
         "backend.api.features.workspace.routes.WorkspaceManager",
         return_value=mock_manager,
@@ -257,14 +391,11 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
     assert mock_manager.write_file.call_args[0][1] == "Makefile"


-# ---- Filename sanitization (SF5) ----
-
-
-def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
+def test_upload_strips_path_components(mocker):
     """Path-traversal filenames should be reduced to their basename."""
     mocker.patch(
         "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
+        return_value=_make_workspace(),
     )
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace_total_size",
@@ -275,28 +406,23 @@ def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
         return_value=None,
     )
     mock_manager = mocker.MagicMock()
-    mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
+    mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
     mocker.patch(
         "backend.api.features.workspace.routes.WorkspaceManager",
         return_value=mock_manager,
     )

     # Filename with traversal
     _upload(filename="../../etc/passwd.txt")

     # write_file should have been called with just the basename
     mock_manager.write_file.assert_called_once()
     call_args = mock_manager.write_file.call_args
     assert call_args[0][1] == "passwd.txt"


-# ---- Download ----
-
-
-def test_download_file_not_found(mocker: pytest_mock.MockFixture):
+def test_download_file_not_found(mocker):
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace",
-        return_value=MOCK_WORKSPACE,
+        return_value=_make_workspace(),
     )
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace_file",
@@ -307,14 +433,11 @@ def test_download_file_not_found(mocker: pytest_mock.MockFixture):
     assert response.status_code == 404


-# ---- Delete ----
-
-
-def test_delete_file_success(mocker: pytest_mock.MockFixture):
+def test_delete_file_success(mocker):
     """Deleting an existing file should return {"deleted": true}."""
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace",
-        return_value=MOCK_WORKSPACE,
+        return_value=_make_workspace(),
     )
     mock_manager = mocker.MagicMock()
     mock_manager.delete_file = mocker.AsyncMock(return_value=True)
@@ -329,11 +452,11 @@ def test_delete_file_success(mocker: pytest_mock.MockFixture):
     mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")


-def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
+def test_delete_file_not_found(mocker):
     """Deleting a non-existent file should return 404."""
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace",
-        return_value=MOCK_WORKSPACE,
+        return_value=_make_workspace(),
     )
     mock_manager = mocker.MagicMock()
     mock_manager.delete_file = mocker.AsyncMock(return_value=False)
@@ -347,7 +470,7 @@ def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
     assert "File not found" in response.text


-def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
+def test_delete_file_no_workspace(mocker):
     """Deleting when user has no workspace should return 404."""
     mocker.patch(
         "backend.api.features.workspace.routes.get_workspace",
@@ -357,3 +480,123 @@ def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
     response = client.delete("/files/file-aaa-bbb")
     assert response.status_code == 404
     assert "Workspace not found" in response.text
+
+
+def test_upload_write_file_too_large_returns_413(mocker):
+    """write_file raises ValueError("File too large: …") → must map to 413."""
+    mocker.patch(
+        "backend.api.features.workspace.routes.get_or_create_workspace",
+        return_value=_make_workspace(),
+    )
+    mocker.patch(
+        "backend.api.features.workspace.routes.get_workspace_total_size",
+        return_value=0,
+    )
+    mocker.patch(
+        "backend.api.features.workspace.routes.scan_content_safe",
+        return_value=None,
+    )
+    mock_manager = mocker.MagicMock()
+    mock_manager.write_file = mocker.AsyncMock(
+        side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit")
+    )
+    mocker.patch(
+        "backend.api.features.workspace.routes.WorkspaceManager",
+        return_value=mock_manager,
+    )
+
+    response = _upload()
+    assert response.status_code == 413
+    assert "File too large" in response.text
+
+
+def test_upload_write_file_conflict_returns_409(mocker):
+    """Non-'File too large' ValueErrors from write_file stay as 409."""
+    mocker.patch(
+        "backend.api.features.workspace.routes.get_or_create_workspace",
+        return_value=_make_workspace(),
+    )
+    mocker.patch(
+        "backend.api.features.workspace.routes.get_workspace_total_size",
+        return_value=0,
+    )
+    mocker.patch(
+        "backend.api.features.workspace.routes.scan_content_safe",
+        return_value=None,
+    )
+    mock_manager = mocker.MagicMock()
+    mock_manager.write_file = mocker.AsyncMock(
+        side_effect=ValueError("File already exists at path: /sessions/x/a.txt")
+    )
+    mocker.patch(
+        "backend.api.features.workspace.routes.WorkspaceManager",
+        return_value=mock_manager,
+    )
+
+    response = _upload()
+    assert response.status_code == 409
+    assert "already exists" in response.text
+
+
+@patch("backend.api.features.workspace.routes.get_or_create_workspace")
+@patch("backend.api.features.workspace.routes.WorkspaceManager")
+def test_list_files_has_more_true_when_limit_exceeded(
+    mock_manager_cls, mock_get_workspace
+):
+    """The limit+1 fetch trick must flip has_more=True and trim the page."""
+    mock_get_workspace.return_value = _make_workspace()
+    # Backend was asked for limit+1=3, and returned exactly 3 items.
+    files = [
+        _make_file(id="f1", name="a.txt"),
+        _make_file(id="f2", name="b.txt"),
+        _make_file(id="f3", name="c.txt"),
+    ]
+    mock_instance = AsyncMock()
+    mock_instance.list_files.return_value = files
+    mock_manager_cls.return_value = mock_instance
+
+    response = client.get("/files?limit=2")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["has_more"] is True
+    assert len(data["files"]) == 2
+    assert data["files"][0]["id"] == "f1"
+    assert data["files"][1]["id"] == "f2"
+    mock_instance.list_files.assert_called_once_with(
+        limit=3, offset=0, include_all_sessions=True
+    )
+
+
+@patch("backend.api.features.workspace.routes.get_or_create_workspace")
+@patch("backend.api.features.workspace.routes.WorkspaceManager")
+def test_list_files_has_more_false_when_exactly_page_size(
+    mock_manager_cls, mock_get_workspace
+):
+    """Exactly `limit` rows means we're on the last page — has_more=False."""
+    mock_get_workspace.return_value = _make_workspace()
+    files = [_make_file(id="f1", name="a.txt"), _make_file(id="f2", name="b.txt")]
+    mock_instance = AsyncMock()
+    mock_instance.list_files.return_value = files
+    mock_manager_cls.return_value = mock_instance
+
+    response = client.get("/files?limit=2")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["has_more"] is False
+    assert len(data["files"]) == 2
+
+
+@patch("backend.api.features.workspace.routes.get_or_create_workspace")
+@patch("backend.api.features.workspace.routes.WorkspaceManager")
+def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
+    mock_get_workspace.return_value = _make_workspace()
+    mock_instance = AsyncMock()
+    mock_instance.list_files.return_value = []
+    mock_manager_cls.return_value = mock_instance
+
+    response = client.get("/files?offset=50&limit=10")
+    assert response.status_code == 200
+    assert response.json()["offset"] == 50
+    mock_instance.list_files.assert_called_once_with(
+        limit=11, offset=50, include_all_sessions=True
+    )
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError

 import backend.api.features.admin.credit_admin_routes
 import backend.api.features.admin.execution_analytics_routes
+import backend.api.features.admin.platform_cost_routes
 import backend.api.features.admin.rate_limit_admin_routes
 import backend.api.features.admin.store_admin_routes
 import backend.api.features.builder
@@ -329,6 +330,11 @@ app.include_router(
     tags=["v2", "admin"],
     prefix="/api/copilot",
 )
+app.include_router(
+    backend.api.features.admin.platform_cost_routes.router,
+    tags=["v2", "admin"],
+    prefix="/api/admin",
+)
 app.include_router(
     backend.api.features.executions.review.routes.router,
     tags=["v2", "executions", "review"],
@@ -17,7 +17,7 @@ from backend.blocks.apollo.models import (
     PrimaryPhone,
     SearchOrganizationsRequest,
 )
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


 class SearchOrganizationsBlock(Block):
@@ -218,6 +218,11 @@ To find IDs, identify the values for organization_id when you call this endpoint
     ) -> BlockOutput:
         query = SearchOrganizationsRequest(**input_data.model_dump())
         organizations = await self.search_organizations(query, credentials)
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=float(len(organizations)), provider_cost_type="items"
+            )
+        )
         for organization in organizations:
             yield "organization", organization
         yield "organizations", organizations
@@ -21,7 +21,7 @@ from backend.blocks.apollo.models import (
    SearchPeopleRequest,
    SenorityLevels,
)
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


class SearchPeopleBlock(Block):
@@ -366,4 +366,9 @@ class SearchPeopleBlock(Block):
            *(enrich_or_fallback(person) for person in people)
        )

        self.merge_stats(
            NodeExecutionStats(
                provider_cost=float(len(people)), provider_cost_type="items"
            )
        )
        yield "people", people

@@ -0,0 +1,712 @@
"""Unit tests for merge_stats cost tracking in individual blocks.

Covers the exa code_context, exa contents, and apollo organization blocks
to verify provider cost is correctly extracted and reported.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, NodeExecutionStats

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

TEST_EXA_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="exa",
    api_key=SecretStr("mock-exa-api-key"),
    title="Mock Exa API key",
    expires_at=None,
)

TEST_EXA_CREDENTIALS_INPUT = {
    "provider": TEST_EXA_CREDENTIALS.provider,
    "id": TEST_EXA_CREDENTIALS.id,
    "type": TEST_EXA_CREDENTIALS.type,
    "title": TEST_EXA_CREDENTIALS.title,
}


# ---------------------------------------------------------------------------
# ExaCodeContextBlock — cost_dollars is a string like "0.005"
# ---------------------------------------------------------------------------


class TestExaCodeContextBlockCostTracking:
    @pytest.mark.asyncio
    async def test_merge_stats_called_with_float_cost(self):
        """float(cost_dollars) parsed from API string and passed to merge_stats."""
        from backend.blocks.exa.code_context import ExaCodeContextBlock

        block = ExaCodeContextBlock()

        api_response = {
            "requestId": "req-1",
            "query": "how to use hooks",
            "response": "Here are some examples...",
            "resultsCount": 3,
            "costDollars": "0.005",
            "searchTime": 1.2,
            "outputTokens": 100,
        }

        mock_resp = MagicMock()
        mock_resp.json.return_value = api_response

        accumulated: list[NodeExecutionStats] = []

        with (
            patch(
                "backend.blocks.exa.code_context.Requests.post",
                new_callable=AsyncMock,
                return_value=mock_resp,
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = ExaCodeContextBlock.Input(
                query="how to use hooks",
                credentials=TEST_EXA_CREDENTIALS_INPUT,  # type: ignore[arg-type]
            )
            results = []
            async for output in block.run(
                input_data,
                credentials=TEST_EXA_CREDENTIALS,
            ):
                results.append(output)

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == pytest.approx(0.005)

    @pytest.mark.asyncio
    async def test_invalid_cost_dollars_does_not_raise(self):
        """When cost_dollars cannot be parsed as float, merge_stats is not called."""
        from backend.blocks.exa.code_context import ExaCodeContextBlock

        block = ExaCodeContextBlock()

        api_response = {
            "requestId": "req-2",
            "query": "query",
            "response": "response",
            "resultsCount": 0,
            "costDollars": "N/A",
            "searchTime": 0.5,
            "outputTokens": 0,
        }

        mock_resp = MagicMock()
        mock_resp.json.return_value = api_response

        merge_calls: list[NodeExecutionStats] = []

        with (
            patch(
                "backend.blocks.exa.code_context.Requests.post",
                new_callable=AsyncMock,
                return_value=mock_resp,
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: merge_calls.append(s)
            ),
        ):
            input_data = ExaCodeContextBlock.Input(
                query="query",
                credentials=TEST_EXA_CREDENTIALS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(
                input_data,
                credentials=TEST_EXA_CREDENTIALS,
            ):
                pass

        assert merge_calls == []

    @pytest.mark.asyncio
    async def test_zero_cost_is_tracked(self):
        """A zero cost_dollars string '0.0' should still be recorded."""
        from backend.blocks.exa.code_context import ExaCodeContextBlock

        block = ExaCodeContextBlock()

        api_response = {
            "requestId": "req-3",
            "query": "query",
            "response": "...",
            "resultsCount": 1,
            "costDollars": "0.0",
            "searchTime": 0.1,
            "outputTokens": 10,
        }

        mock_resp = MagicMock()
        mock_resp.json.return_value = api_response

        accumulated: list[NodeExecutionStats] = []

        with (
            patch(
                "backend.blocks.exa.code_context.Requests.post",
                new_callable=AsyncMock,
                return_value=mock_resp,
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = ExaCodeContextBlock.Input(
                query="query",
                credentials=TEST_EXA_CREDENTIALS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(
                input_data,
                credentials=TEST_EXA_CREDENTIALS,
            ):
                pass

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == 0.0


# ---------------------------------------------------------------------------
# ExaContentsBlock — response.cost_dollars.total (CostDollars model)
# ---------------------------------------------------------------------------


class TestExaContentsBlockCostTracking:
    @pytest.mark.asyncio
    async def test_merge_stats_called_with_cost_dollars_total(self):
        """provider_cost equals response.cost_dollars.total when present."""
        from backend.blocks.exa.contents import ExaContentsBlock
        from backend.blocks.exa.helpers import CostDollars

        block = ExaContentsBlock()

        cost_dollars = CostDollars(total=0.012)

        mock_response = MagicMock()
        mock_response.results = []
        mock_response.context = None
        mock_response.statuses = None
        mock_response.cost_dollars = cost_dollars

        accumulated: list[NodeExecutionStats] = []

        with (
            patch(
                "backend.blocks.exa.contents.AsyncExa",
                return_value=MagicMock(
                    get_contents=AsyncMock(return_value=mock_response)
                ),
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = ExaContentsBlock.Input(
                urls=["https://example.com"],
                credentials=TEST_EXA_CREDENTIALS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(
                input_data,
                credentials=TEST_EXA_CREDENTIALS,
            ):
                pass

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == pytest.approx(0.012)

    @pytest.mark.asyncio
    async def test_no_merge_stats_when_cost_dollars_absent(self):
        """When response.cost_dollars is None, merge_stats is not called."""
        from backend.blocks.exa.contents import ExaContentsBlock

        block = ExaContentsBlock()

        mock_response = MagicMock()
        mock_response.results = []
        mock_response.context = None
        mock_response.statuses = None
        mock_response.cost_dollars = None

        accumulated: list[NodeExecutionStats] = []

        with (
            patch(
                "backend.blocks.exa.contents.AsyncExa",
                return_value=MagicMock(
                    get_contents=AsyncMock(return_value=mock_response)
                ),
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = ExaContentsBlock.Input(
                urls=["https://example.com"],
                credentials=TEST_EXA_CREDENTIALS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(
                input_data,
                credentials=TEST_EXA_CREDENTIALS,
            ):
                pass

        assert accumulated == []


# ---------------------------------------------------------------------------
# SearchOrganizationsBlock — provider_cost = float(len(organizations))
# ---------------------------------------------------------------------------


class TestSearchOrganizationsBlockCostTracking:
    @pytest.mark.asyncio
    async def test_merge_stats_called_with_org_count(self):
        """provider_cost == number of returned organizations, type == 'items'."""
        from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
        from backend.blocks.apollo._auth import (
            TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
        )
        from backend.blocks.apollo.models import Organization
        from backend.blocks.apollo.organization import SearchOrganizationsBlock

        block = SearchOrganizationsBlock()

        fake_orgs = [Organization(id=str(i), name=f"Org{i}") for i in range(3)]

        accumulated: list[NodeExecutionStats] = []

        with (
            patch.object(
                SearchOrganizationsBlock,
                "search_organizations",
                new_callable=AsyncMock,
                return_value=fake_orgs,
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = SearchOrganizationsBlock.Input(
                credentials=APOLLO_CREDS_INPUT,  # type: ignore[arg-type]
            )
            results = []
            async for output in block.run(
                input_data,
                credentials=APOLLO_CREDS,
            ):
                results.append(output)

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == pytest.approx(3.0)
        assert accumulated[0].provider_cost_type == "items"

    @pytest.mark.asyncio
    async def test_empty_org_list_tracks_zero(self):
        """An empty organization list results in provider_cost=0.0."""
        from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
        from backend.blocks.apollo._auth import (
            TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
        )
        from backend.blocks.apollo.organization import SearchOrganizationsBlock

        block = SearchOrganizationsBlock()
        accumulated: list[NodeExecutionStats] = []

        with (
            patch.object(
                SearchOrganizationsBlock,
                "search_organizations",
                new_callable=AsyncMock,
                return_value=[],
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = SearchOrganizationsBlock.Input(
                credentials=APOLLO_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(
                input_data,
                credentials=APOLLO_CREDS,
            ):
                pass

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == 0.0
        assert accumulated[0].provider_cost_type == "items"


# ---------------------------------------------------------------------------
# JinaEmbeddingBlock — token count from usage.total_tokens
# ---------------------------------------------------------------------------


class TestJinaEmbeddingBlockCostTracking:
    @pytest.mark.asyncio
    async def test_merge_stats_called_with_token_count(self):
        """provider token count is recorded when API returns usage.total_tokens."""
        from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
        from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
        from backend.blocks.jina.embeddings import JinaEmbeddingBlock

        block = JinaEmbeddingBlock()

        api_response = {
            "data": [{"embedding": [0.1, 0.2, 0.3]}],
            "usage": {"total_tokens": 42},
        }
        mock_resp = MagicMock()
        mock_resp.json.return_value = api_response

        accumulated: list[NodeExecutionStats] = []

        with (
            patch(
                "backend.blocks.jina.embeddings.Requests.post",
                new_callable=AsyncMock,
                return_value=mock_resp,
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = JinaEmbeddingBlock.Input(
                texts=["hello world"],
                credentials=JINA_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=JINA_CREDS):
                pass

        assert len(accumulated) == 1
        assert accumulated[0].input_token_count == 42

    @pytest.mark.asyncio
    async def test_no_merge_stats_when_usage_absent(self):
        """When API response omits usage field, merge_stats is not called."""
        from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
        from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
        from backend.blocks.jina.embeddings import JinaEmbeddingBlock

        block = JinaEmbeddingBlock()

        api_response = {
            "data": [{"embedding": [0.1, 0.2, 0.3]}],
        }
        mock_resp = MagicMock()
        mock_resp.json.return_value = api_response

        accumulated: list[NodeExecutionStats] = []

        with (
            patch(
                "backend.blocks.jina.embeddings.Requests.post",
                new_callable=AsyncMock,
                return_value=mock_resp,
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = JinaEmbeddingBlock.Input(
                texts=["hello"],
                credentials=JINA_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=JINA_CREDS):
                pass

        assert accumulated == []


# ---------------------------------------------------------------------------
# UnrealTextToSpeechBlock — character count from input text length
# ---------------------------------------------------------------------------


class TestUnrealTextToSpeechBlockCostTracking:
    @pytest.mark.asyncio
    async def test_merge_stats_called_with_character_count(self):
        """provider_cost equals len(text) with type='characters'."""
        from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
        from backend.blocks.text_to_speech_block import (
            TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
        )
        from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock

        block = UnrealTextToSpeechBlock()
        test_text = "Hello, world!"

        with (
            patch.object(
                UnrealTextToSpeechBlock,
                "call_unreal_speech_api",
                new_callable=AsyncMock,
                return_value={"OutputUri": "https://example.com/audio.mp3"},
            ),
            patch.object(block, "merge_stats") as mock_merge,
        ):
            input_data = UnrealTextToSpeechBlock.Input(
                text=test_text,
                credentials=TTS_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=TTS_CREDS):
                pass

        mock_merge.assert_called_once()
        stats = mock_merge.call_args[0][0]
        assert stats.provider_cost == float(len(test_text))
        assert stats.provider_cost_type == "characters"

    @pytest.mark.asyncio
    async def test_empty_text_gives_zero_characters(self):
        """An empty text string results in provider_cost=0.0."""
        from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
        from backend.blocks.text_to_speech_block import (
            TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
        )
        from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock

        block = UnrealTextToSpeechBlock()

        with (
            patch.object(
                UnrealTextToSpeechBlock,
                "call_unreal_speech_api",
                new_callable=AsyncMock,
                return_value={"OutputUri": "https://example.com/audio.mp3"},
            ),
            patch.object(block, "merge_stats") as mock_merge,
        ):
            input_data = UnrealTextToSpeechBlock.Input(
                text="",
                credentials=TTS_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=TTS_CREDS):
                pass

        mock_merge.assert_called_once()
        stats = mock_merge.call_args[0][0]
        assert stats.provider_cost == 0.0
        assert stats.provider_cost_type == "characters"


# ---------------------------------------------------------------------------
# GoogleMapsSearchBlock — item count from search_places results
# ---------------------------------------------------------------------------


class TestGoogleMapsSearchBlockCostTracking:
    @pytest.mark.asyncio
    async def test_merge_stats_called_with_place_count(self):
        """provider_cost equals number of returned places, type == 'items'."""
        from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
        from backend.blocks.google_maps import (
            TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
        )
        from backend.blocks.google_maps import GoogleMapsSearchBlock

        block = GoogleMapsSearchBlock()

        fake_places = [{"name": f"Place{i}", "address": f"Addr{i}"} for i in range(4)]
        accumulated: list[NodeExecutionStats] = []

        with (
            patch.object(
                GoogleMapsSearchBlock,
                "search_places",
                return_value=fake_places,
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = GoogleMapsSearchBlock.Input(
                query="coffee shops",
                credentials=MAPS_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=MAPS_CREDS):
                pass

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == 4.0
        assert accumulated[0].provider_cost_type == "items"

    @pytest.mark.asyncio
    async def test_empty_results_tracks_zero(self):
        """Zero places returned results in provider_cost=0.0."""
        from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
        from backend.blocks.google_maps import (
            TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
        )
        from backend.blocks.google_maps import GoogleMapsSearchBlock

        block = GoogleMapsSearchBlock()
        accumulated: list[NodeExecutionStats] = []

        with (
            patch.object(
                GoogleMapsSearchBlock,
                "search_places",
                return_value=[],
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = GoogleMapsSearchBlock.Input(
                query="nothing here",
                credentials=MAPS_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=MAPS_CREDS):
                pass

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == 0.0
        assert accumulated[0].provider_cost_type == "items"


# ---------------------------------------------------------------------------
# SmartLeadAddLeadsBlock — item count from lead_list length
# ---------------------------------------------------------------------------


class TestSmartLeadAddLeadsBlockCostTracking:
    @pytest.mark.asyncio
    async def test_merge_stats_called_with_lead_count(self):
        """provider_cost equals number of leads uploaded, type == 'items'."""
        from backend.blocks.smartlead._auth import TEST_CREDENTIALS as SL_CREDS
        from backend.blocks.smartlead._auth import (
            TEST_CREDENTIALS_INPUT as SL_CREDS_INPUT,
        )
        from backend.blocks.smartlead.campaign import AddLeadToCampaignBlock
        from backend.blocks.smartlead.models import (
            AddLeadsToCampaignResponse,
            LeadInput,
        )

        block = AddLeadToCampaignBlock()

        fake_leads = [
            LeadInput(first_name="Alice", last_name="A", email="alice@example.com"),
            LeadInput(first_name="Bob", last_name="B", email="bob@example.com"),
        ]
        fake_response = AddLeadsToCampaignResponse(
            ok=True,
            upload_count=2,
            total_leads=2,
            block_count=0,
            duplicate_count=0,
            invalid_email_count=0,
            invalid_emails=[],
            already_added_to_campaign=0,
            unsubscribed_leads=[],
            is_lead_limit_exhausted=False,
            lead_import_stopped_count=0,
            bounce_count=0,
        )
        accumulated: list[NodeExecutionStats] = []

        with (
            patch.object(
                AddLeadToCampaignBlock,
                "add_leads_to_campaign",
                new_callable=AsyncMock,
                return_value=fake_response,
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = AddLeadToCampaignBlock.Input(
                campaign_id=123,
                lead_list=fake_leads,
                credentials=SL_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=SL_CREDS):
                pass

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == 2.0
        assert accumulated[0].provider_cost_type == "items"


# ---------------------------------------------------------------------------
# SearchPeopleBlock — item count from people list length
# ---------------------------------------------------------------------------


class TestSearchPeopleBlockCostTracking:
    @pytest.mark.asyncio
    async def test_merge_stats_called_with_people_count(self):
        """provider_cost equals number of returned people, type == 'items'."""
        from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
        from backend.blocks.apollo._auth import (
            TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
        )
        from backend.blocks.apollo.models import Contact
        from backend.blocks.apollo.people import SearchPeopleBlock

        block = SearchPeopleBlock()
        fake_people = [Contact(id=str(i), first_name=f"Person{i}") for i in range(5)]
        accumulated: list[NodeExecutionStats] = []

        with (
            patch.object(
                SearchPeopleBlock,
                "search_people",
                new_callable=AsyncMock,
                return_value=fake_people,
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = SearchPeopleBlock.Input(
                credentials=APOLLO_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=APOLLO_CREDS):
                pass

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == pytest.approx(5.0)
        assert accumulated[0].provider_cost_type == "items"

    @pytest.mark.asyncio
    async def test_empty_people_list_tracks_zero(self):
        """An empty people list results in provider_cost=0.0."""
        from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
        from backend.blocks.apollo._auth import (
            TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
        )
        from backend.blocks.apollo.people import SearchPeopleBlock

        block = SearchPeopleBlock()
        accumulated: list[NodeExecutionStats] = []

        with (
            patch.object(
                SearchPeopleBlock,
                "search_people",
                new_callable=AsyncMock,
                return_value=[],
            ),
            patch.object(
                block, "merge_stats", side_effect=lambda s: accumulated.append(s)
            ),
        ):
            input_data = SearchPeopleBlock.Input(
                credentials=APOLLO_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=APOLLO_CREDS):
                pass

        assert len(accumulated) == 1
        assert accumulated[0].provider_cost == 0.0
        assert accumulated[0].provider_cost_type == "items"
@@ -9,6 +9,7 @@ from typing import Union

from pydantic import BaseModel

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -116,3 +117,10 @@ class ExaCodeContextBlock(Block):
        yield "cost_dollars", context.cost_dollars
        yield "search_time", context.search_time
        yield "output_tokens", context.output_tokens

        # Parse cost_dollars (API returns as string, e.g. "0.005")
        try:
            cost_usd = float(context.cost_dollars)
            self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
        except (ValueError, TypeError):
            pass

@@ -4,6 +4,7 @@ from typing import Optional
from exa_py import AsyncExa
from pydantic import BaseModel

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -223,3 +224,6 @@ class ExaContentsBlock(Block):

        if response.cost_dollars:
            yield "cost_dollars", response.cost_dollars
            self.merge_stats(
                NodeExecutionStats(provider_cost=response.cost_dollars.total)
            )

@@ -0,0 +1,575 @@
|
||||
"""Tests for cost tracking in Exa blocks.
|
||||
|
||||
Covers the cost_dollars → provider_cost → merge_stats path for both
|
||||
ExaContentsBlock and ExaCodeContextBlock.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
|
||||
class TestExaCodeContextCostTracking:
|
||||
"""ExaCodeContextBlock parses cost_dollars (string) and calls merge_stats."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_cost_string_is_parsed_and_merged(self):
|
||||
"""A numeric cost string like '0.005' is merged as provider_cost."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-1",
|
||||
"query": "test query",
|
||||
"response": "some code",
|
||||
"resultsCount": 3,
|
||||
"costDollars": "0.005",
|
||||
"searchTime": 1.2,
|
||||
"outputTokens": 100,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
outputs = []
|
||||
async for key, value in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
outputs.append((key, value))
|
||||
|
||||
assert any(k == "cost_dollars" for k, _ in outputs)
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.005)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cost_string_does_not_raise(self):
|
||||
"""A non-numeric cost_dollars value is swallowed silently."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-2",
|
||||
"query": "test",
|
||||
"response": "code",
|
||||
"resultsCount": 0,
|
||||
"costDollars": "N/A",
|
||||
"searchTime": 0.5,
|
||||
"outputTokens": 0,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
outputs = []
|
||||
async for key, value in block.run(
|
||||
block.Input(query="test", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
outputs.append((key, value))
|
||||
|
||||
# No merge_stats call because float() raised ValueError
|
||||
assert len(merged) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cost_string_is_merged(self):
|
||||
"""'0.0' is a valid cost — should still be tracked."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-3",
|
||||
"query": "free query",
|
||||
"response": "result",
|
||||
"resultsCount": 1,
|
||||
"costDollars": "0.0",
|
||||
"searchTime": 0.1,
|
||||
"outputTokens": 10,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="free query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.0)
|
||||
|
||||
|
||||
class TestExaContentsCostTracking:
    """ExaContentsBlock merges cost_dollars.total as provider_cost."""

    @pytest.mark.asyncio
    async def test_cost_dollars_total_is_merged(self):
        """When the SDK response includes cost_dollars, its total is merged."""
        from backend.blocks.exa.contents import ExaContentsBlock
        from backend.blocks.exa.helpers import CostDollars

        block = ExaContentsBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.statuses = None
        mock_sdk_response.cost_dollars = CostDollars(total=0.012)

        with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.012)

    @pytest.mark.asyncio
    async def test_no_cost_dollars_skips_merge(self):
        """When cost_dollars is absent, merge_stats is not called."""
        from backend.blocks.exa.contents import ExaContentsBlock

        block = ExaContentsBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.statuses = None
        mock_sdk_response.cost_dollars = None

        with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 0

    @pytest.mark.asyncio
    async def test_zero_cost_dollars_is_merged(self):
        """A total of 0.0 (free tier) should still be merged."""
        from backend.blocks.exa.contents import ExaContentsBlock
        from backend.blocks.exa.helpers import CostDollars

        block = ExaContentsBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.statuses = None
        mock_sdk_response.cost_dollars = CostDollars(total=0.0)

        with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.0)

class TestExaSearchCostTracking:
    """ExaSearchBlock merges cost_dollars.total as provider_cost."""

    @pytest.mark.asyncio
    async def test_cost_dollars_total_is_merged(self):
        """When the SDK response includes cost_dollars, its total is merged."""
        from backend.blocks.exa.helpers import CostDollars
        from backend.blocks.exa.search import ExaSearchBlock

        block = ExaSearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.resolved_search_type = None
        mock_sdk_response.cost_dollars = CostDollars(total=0.008)

        with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.search = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.008)

    @pytest.mark.asyncio
    async def test_no_cost_dollars_skips_merge(self):
        """When cost_dollars is absent, merge_stats is not called."""
        from backend.blocks.exa.search import ExaSearchBlock

        block = ExaSearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.resolved_search_type = None
        mock_sdk_response.cost_dollars = None

        with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.search = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 0

class TestExaSimilarCostTracking:
    """ExaFindSimilarBlock merges cost_dollars.total as provider_cost."""

    @pytest.mark.asyncio
    async def test_cost_dollars_total_is_merged(self):
        """When the SDK response includes cost_dollars, its total is merged."""
        from backend.blocks.exa.helpers import CostDollars
        from backend.blocks.exa.similar import ExaFindSimilarBlock

        block = ExaFindSimilarBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.request_id = "req-1"
        mock_sdk_response.cost_dollars = CostDollars(total=0.015)

        with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.015)

    @pytest.mark.asyncio
    async def test_no_cost_dollars_skips_merge(self):
        """When cost_dollars is absent, merge_stats is not called."""
        from backend.blocks.exa.similar import ExaFindSimilarBlock

        block = ExaFindSimilarBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        mock_sdk_response = MagicMock()
        mock_sdk_response.results = []
        mock_sdk_response.context = None
        mock_sdk_response.request_id = "req-2"
        mock_sdk_response.cost_dollars = None

        with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
            mock_exa = MagicMock()
            mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
            mock_exa_cls.return_value = mock_exa

            async for _ in block.run(
                block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 0

# ---------------------------------------------------------------------------
# ExaCreateResearchBlock — cost_dollars from completed poll response
# ---------------------------------------------------------------------------


COMPLETED_RESEARCH_RESPONSE = {
    "researchId": "test-research-id",
    "status": "completed",
    "model": "exa-research",
    "instructions": "test instructions",
    "createdAt": 1700000000000,
    "finishedAt": 1700000060000,
    "costDollars": {
        "total": 0.05,
        "numSearches": 3,
        "numPages": 10,
        "reasoningTokens": 500,
    },
    "output": {"content": "Research findings...", "parsed": None},
}

PENDING_RESEARCH_RESPONSE = {
    "researchId": "test-research-id",
    "status": "pending",
    "model": "exa-research",
    "instructions": "test instructions",
    "createdAt": 1700000000000,
}

class TestExaCreateResearchBlockCostTracking:
    """ExaCreateResearchBlock merges cost from completed poll response."""

    @pytest.mark.asyncio
    async def test_cost_merged_when_research_completes(self):
        """merge_stats called with provider_cost=total when poll returns completed."""
        from backend.blocks.exa.research import ExaCreateResearchBlock

        block = ExaCreateResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        create_resp = MagicMock()
        create_resp.json.return_value = PENDING_RESEARCH_RESPONSE

        poll_resp = MagicMock()
        poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE

        mock_instance = MagicMock()
        mock_instance.post = AsyncMock(return_value=create_resp)
        mock_instance.get = AsyncMock(return_value=poll_resp)

        with (
            patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
            patch("asyncio.sleep", new=AsyncMock()),
        ):
            async for _ in block.run(
                block.Input(
                    instructions="test instructions",
                    wait_for_completion=True,
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.05)

    @pytest.mark.asyncio
    async def test_no_merge_when_no_cost_dollars(self):
        """When completed response has no costDollars, merge_stats is not called."""
        from backend.blocks.exa.research import ExaCreateResearchBlock

        block = ExaCreateResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
        create_resp = MagicMock()
        create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
        poll_resp = MagicMock()
        poll_resp.json.return_value = no_cost_response

        mock_instance = MagicMock()
        mock_instance.post = AsyncMock(return_value=create_resp)
        mock_instance.get = AsyncMock(return_value=poll_resp)

        with (
            patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
            patch("asyncio.sleep", new=AsyncMock()),
        ):
            async for _ in block.run(
                block.Input(
                    instructions="test instructions",
                    wait_for_completion=True,
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert merged == []

# ---------------------------------------------------------------------------
# ExaGetResearchBlock — cost_dollars from single GET response
# ---------------------------------------------------------------------------


class TestExaGetResearchBlockCostTracking:
    """ExaGetResearchBlock merges cost when the fetched research has cost_dollars."""

    @pytest.mark.asyncio
    async def test_cost_merged_from_completed_research(self):
        """merge_stats called with provider_cost=total when research has costDollars."""
        from backend.blocks.exa.research import ExaGetResearchBlock

        block = ExaGetResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        get_resp = MagicMock()
        get_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE

        mock_instance = MagicMock()
        mock_instance.get = AsyncMock(return_value=get_resp)

        with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
            async for _ in block.run(
                block.Input(
                    research_id="test-research-id",
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.05)

    @pytest.mark.asyncio
    async def test_no_merge_when_no_cost_dollars(self):
        """When research has no costDollars, merge_stats is not called."""
        from backend.blocks.exa.research import ExaGetResearchBlock

        block = ExaGetResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
        get_resp = MagicMock()
        get_resp.json.return_value = no_cost_response

        mock_instance = MagicMock()
        mock_instance.get = AsyncMock(return_value=get_resp)

        with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
            async for _ in block.run(
                block.Input(
                    research_id="test-research-id",
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert merged == []

# ---------------------------------------------------------------------------
# ExaWaitForResearchBlock — cost_dollars from polling response
# ---------------------------------------------------------------------------


class TestExaWaitForResearchBlockCostTracking:
    """ExaWaitForResearchBlock merges cost when the polled research has cost_dollars."""

    @pytest.mark.asyncio
    async def test_cost_merged_when_research_completes(self):
        """merge_stats called with provider_cost=total once polling returns completed."""
        from backend.blocks.exa.research import ExaWaitForResearchBlock

        block = ExaWaitForResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        poll_resp = MagicMock()
        poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE

        mock_instance = MagicMock()
        mock_instance.get = AsyncMock(return_value=poll_resp)

        with (
            patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
            patch("asyncio.sleep", new=AsyncMock()),
        ):
            async for _ in block.run(
                block.Input(
                    research_id="test-research-id",
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(merged) == 1
        assert merged[0].provider_cost == pytest.approx(0.05)

    @pytest.mark.asyncio
    async def test_no_merge_when_no_cost_dollars(self):
        """When completed research has no costDollars, merge_stats is not called."""
        from backend.blocks.exa.research import ExaWaitForResearchBlock

        block = ExaWaitForResearchBlock()
        merged: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: merged.append(s)  # type: ignore[assignment]

        no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
        poll_resp = MagicMock()
        poll_resp.json.return_value = no_cost_response

        mock_instance = MagicMock()
        mock_instance.get = AsyncMock(return_value=poll_resp)

        with (
            patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
            patch("asyncio.sleep", new=AsyncMock()),
        ):
            async for _ in block.run(
                block.Input(
                    research_id="test-research-id",
                    credentials=TEST_CREDENTIALS_INPUT,  # type: ignore[arg-type]
                ),
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert merged == []

@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel
 
+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -232,6 +233,11 @@ class ExaCreateResearchBlock(Block):
 
                 if research.cost_dollars:
                     yield "cost_total", research.cost_dollars.total
+                    self.merge_stats(
+                        NodeExecutionStats(
+                            provider_cost=research.cost_dollars.total
+                        )
+                    )
                 return
 
             await asyncio.sleep(check_interval)
@@ -346,6 +352,9 @@ class ExaGetResearchBlock(Block):
             yield "cost_searches", research.cost_dollars.num_searches
             yield "cost_pages", research.cost_dollars.num_pages
             yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
+            self.merge_stats(
+                NodeExecutionStats(provider_cost=research.cost_dollars.total)
+            )
 
         yield "error_message", research.error
 
@@ -432,6 +441,9 @@ class ExaWaitForResearchBlock(Block):
 
                 if research.cost_dollars:
                     yield "cost_total", research.cost_dollars.total
+                    self.merge_stats(
+                        NodeExecutionStats(provider_cost=research.cost_dollars.total)
+                    )
 
                 return

@@ -4,6 +4,7 @@ from typing import Optional
 
 from exa_py import AsyncExa
 
+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -206,3 +207,6 @@ class ExaSearchBlock(Block):
 
         if response.cost_dollars:
             yield "cost_dollars", response.cost_dollars
+            self.merge_stats(
+                NodeExecutionStats(provider_cost=response.cost_dollars.total)
+            )

@@ -3,6 +3,7 @@ from typing import Optional
 
 from exa_py import AsyncExa
 
+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -167,3 +168,6 @@ class ExaFindSimilarBlock(Block):
 
         if response.cost_dollars:
             yield "cost_dollars", response.cost_dollars
+            self.merge_stats(
+                NodeExecutionStats(provider_cost=response.cost_dollars.total)
+            )

@@ -14,6 +14,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
@@ -117,6 +118,11 @@ class GoogleMapsSearchBlock(Block):
             input_data.radius,
             input_data.max_results,
         )
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=float(len(places)), provider_cost_type="items"
+            )
+        )
         for place in places:
             yield "place", place

@@ -10,7 +10,7 @@ from backend.blocks.jina._auth import (
     JinaCredentialsField,
     JinaCredentialsInput,
 )
-from backend.data.model import SchemaField
+from backend.data.model import NodeExecutionStats, SchemaField
 from backend.util.request import Requests
 
 
@@ -45,5 +45,13 @@ class JinaEmbeddingBlock(Block):
         }
         data = {"input": input_data.texts, "model": input_data.model}
         response = await Requests().post(url, headers=headers, json=data)
-        embeddings = [e["embedding"] for e in response.json()["data"]]
+        resp_json = response.json()
+        embeddings = [e["embedding"] for e in resp_json["data"]]
+        usage = resp_json.get("usage", {})
+        if usage.get("total_tokens"):
+            self.merge_stats(
+                NodeExecutionStats(
+                    input_token_count=usage.get("total_tokens", 0),
+                )
+            )
         yield "embeddings", embeddings

@@ -1,6 +1,7 @@
 # This file contains a lot of prompt block strings that would trigger "line too long"
 # flake8: noqa: E501
 import logging
+import math
 import re
 import secrets
 from abc import ABC
@@ -13,6 +14,7 @@ import ollama
 import openai
 from anthropic.types import ToolParam
 from groq import AsyncGroq
+from openai.types.chat import ChatCompletion as OpenAIChatCompletion
 from pydantic import BaseModel, SecretStr
 
 from backend.blocks._base import (
@@ -737,6 +739,7 @@ class LLMResponse(BaseModel):
     prompt_tokens: int
     completion_tokens: int
     reasoning: Optional[str] = None
+    provider_cost: float | None = None
 
 
 def convert_openai_tool_fmt_to_anthropic(
@@ -771,6 +774,35 @@
     return anthropic_tools
 
 
+def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None:
+    """Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response.
+
+    OpenRouter returns the per-request USD cost in a response header. The
+    OpenAI SDK exposes the raw httpx response via an undocumented `_response`
+    attribute. We use try/except AttributeError so that if the SDK ever drops
+    or renames that attribute, the warning is visible in logs rather than
+    silently degrading to no cost tracking.
+    """
+    try:
+        raw_resp = response._response  # type: ignore[attr-defined]
+    except AttributeError:
+        logger.warning(
+            "OpenAI SDK response missing _response attribute"
+            " — OpenRouter cost tracking unavailable"
+        )
+        return None
+    try:
+        cost_header = raw_resp.headers.get("x-total-cost")
+        if not cost_header:
+            return None
+        cost = float(cost_header)
+        if not math.isfinite(cost) or cost < 0:
+            return None
+        return cost
+    except (ValueError, TypeError, AttributeError):
+        return None
+
+
 def extract_openai_reasoning(response) -> str | None:
     """Extract reasoning from OpenAI-compatible response if available."""
     """Note: This will likely not working since the reasoning is not present in another Response API"""
@@ -1103,6 +1135,7 @@ async def llm_call(
             prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
             completion_tokens=response.usage.completion_tokens if response.usage else 0,
             reasoning=reasoning,
+            provider_cost=extract_openrouter_cost(response),
         )
     elif provider == "llama_api":
         tools_param = tools if tools else openai.NOT_GIVEN
@@ -1410,6 +1443,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
 
         error_feedback_message = ""
         llm_model = input_data.model
+        last_attempt_cost: float | None = None
 
         for retry_count in range(input_data.retry):
             logger.debug(f"LLM request: {prompt}")
@@ -1427,12 +1461,15 @@
                 max_tokens=input_data.max_tokens,
             )
             response_text = llm_response.response
-            self.merge_stats(
-                NodeExecutionStats(
-                    input_token_count=llm_response.prompt_tokens,
-                    output_token_count=llm_response.completion_tokens,
-                )
-            )
+            # Merge token counts for every attempt (each call costs tokens).
+            # provider_cost (actual USD) is tracked separately and only merged
+            # on success to avoid double-counting across retries.
+            token_stats = NodeExecutionStats(
+                input_token_count=llm_response.prompt_tokens,
+                output_token_count=llm_response.completion_tokens,
+            )
+            self.merge_stats(token_stats)
+            last_attempt_cost = llm_response.provider_cost
             logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
 
             if input_data.expected_format:
@@ -1501,6 +1538,7 @@
                     NodeExecutionStats(
                         llm_call_count=retry_count + 1,
                         llm_retry_count=retry_count,
+                        provider_cost=last_attempt_cost,
                     )
                 )
                 yield "response", response_obj
@@ -1521,6 +1559,7 @@
                 NodeExecutionStats(
                     llm_call_count=retry_count + 1,
                     llm_retry_count=retry_count,
+                    provider_cost=last_attempt_cost,
                 )
             )
             yield "response", {"response": response_text}

@@ -251,8 +251,13 @@ def _convert_raw_response_to_dict(
         # Already a dict (from tests or some providers)
         return raw_response
     elif _is_responses_api_object(raw_response):
-        # OpenAI Responses API: extract individual output items
-        items = [json.to_dict(item) for item in raw_response.output]
+        # OpenAI Responses API: extract individual output items.
+        # Strip 'status' — it's a response-only field that OpenAI rejects
+        # when the item is sent back as input on the next API call.
+        items = [
+            {k: v for k, v in json.to_dict(item).items() if k != "status"}
+            for item in raw_response.output
+        ]
         return items if items else [{"role": "assistant", "content": ""}]
     else:
         # Chat Completions / Anthropic return message objects

@@ -23,7 +23,7 @@ from backend.blocks.smartlead.models import (
     SaveSequencesResponse,
     Sequence,
 )
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
 
 
 class CreateCampaignBlock(Block):
@@ -226,6 +226,12 @@ class AddLeadToCampaignBlock(Block):
         response = await self.add_leads_to_campaign(
             input_data.campaign_id, input_data.lead_list, credentials
         )
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=float(len(input_data.lead_list)),
+                provider_cost_type="items",
+            )
+        )
 
         yield "campaign_id", input_data.campaign_id
         yield "upload_count", response.upload_count

@@ -199,6 +199,66 @@ class TestLLMStatsTracking:
         assert block.execution_stats.llm_call_count == 2  # retry_count + 1 = 1 + 1 = 2
         assert block.execution_stats.llm_retry_count == 1
 
+    @pytest.mark.asyncio
+    async def test_retry_cost_uses_last_attempt_only(self):
+        """provider_cost is only merged from the final successful attempt.
+
+        Intermediate retry costs are intentionally dropped to avoid
+        double-counting: the cost of failed attempts is captured in
+        last_attempt_cost only when the loop eventually succeeds.
+        """
+        import backend.blocks.llm as llm
+
+        block = llm.AIStructuredResponseGeneratorBlock()
+        call_count = 0
+
+        async def mock_llm_call(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                # First attempt: fails validation, returns cost $0.01
+                return llm.LLMResponse(
+                    raw_response="",
+                    prompt=[],
+                    response='<json_output id="test123456">{"wrong": "key"}</json_output>',
+                    tool_calls=None,
+                    prompt_tokens=10,
+                    completion_tokens=5,
+                    reasoning=None,
+                    provider_cost=0.01,
+                )
+            # Second attempt: succeeds, returns cost $0.02
+            return llm.LLMResponse(
+                raw_response="",
+                prompt=[],
+                response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
+                tool_calls=None,
+                prompt_tokens=20,
+                completion_tokens=10,
+                reasoning=None,
+                provider_cost=0.02,
+            )
+
+        block.llm_call = mock_llm_call  # type: ignore
+
+        input_data = llm.AIStructuredResponseGeneratorBlock.Input(
+            prompt="Test prompt",
+            expected_format={"key1": "desc1", "key2": "desc2"},
+            model=llm.DEFAULT_LLM_MODEL,
+            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
+            retry=2,
+        )
+
+        with patch("secrets.token_hex", return_value="test123456"):
+            async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
+                pass
+
+        # Only the final successful attempt's cost is merged
+        assert block.execution_stats.provider_cost == pytest.approx(0.02)
+        # Tokens from both attempts accumulate
+        assert block.execution_stats.input_token_count == 30
+        assert block.execution_stats.output_token_count == 15
+
     @pytest.mark.asyncio
     async def test_ai_text_summarizer_multiple_chunks(self):
         """Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""
@@ -987,3 +1047,67 @@ class TestLlmModelMissing:
         assert (
             llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
         )
+
+
+class TestExtractOpenRouterCost:
+    """Tests for extract_openrouter_cost — the x-total-cost header parser."""
+
+    def _mk_response(self, headers: dict | None):
+        response = MagicMock()
+        if headers is None:
+            response._response = None
+        else:
+            raw = MagicMock()
+            raw.headers = headers
+            response._response = raw
+        return response
+
+    def test_extracts_numeric_cost(self):
+        response = self._mk_response({"x-total-cost": "0.0042"})
+        assert llm.extract_openrouter_cost(response) == 0.0042
+
+    def test_returns_none_when_header_missing(self):
+        response = self._mk_response({})
+        assert llm.extract_openrouter_cost(response) is None
+
+    def test_returns_none_when_header_empty_string(self):
+        response = self._mk_response({"x-total-cost": ""})
+        assert llm.extract_openrouter_cost(response) is None
+
+    def test_returns_none_when_header_non_numeric(self):
+        response = self._mk_response({"x-total-cost": "not-a-number"})
+        assert llm.extract_openrouter_cost(response) is None
+
+    def test_returns_none_when_no_response_attr(self):
+        response = MagicMock(spec=[])  # no _response attr
+        assert llm.extract_openrouter_cost(response) is None
+
+    def test_returns_none_when_raw_is_none(self):
+        response = self._mk_response(None)
+        assert llm.extract_openrouter_cost(response) is None
+
+    def test_returns_none_when_raw_has_no_headers(self):
+        response = MagicMock()
+        response._response = MagicMock(spec=[])  # no headers attr
+        assert llm.extract_openrouter_cost(response) is None
+
+    def test_returns_zero_for_zero_cost(self):
+        """Zero-cost is a valid value (free tier) and must not become None."""
+        response = self._mk_response({"x-total-cost": "0"})
+        assert llm.extract_openrouter_cost(response) == 0.0
+
+    def test_returns_none_for_inf(self):
+        response = self._mk_response({"x-total-cost": "inf"})
+        assert llm.extract_openrouter_cost(response) is None
+
+    def test_returns_none_for_negative_inf(self):
+        response = self._mk_response({"x-total-cost": "-inf"})
+        assert llm.extract_openrouter_cost(response) is None
+
+    def test_returns_none_for_nan(self):
+        response = self._mk_response({"x-total-cost": "nan"})
+        assert llm.extract_openrouter_cost(response) is None
+
+    def test_returns_none_for_negative_cost(self):
+        response = self._mk_response({"x-total-cost": "-0.005"})
+        assert llm.extract_openrouter_cost(response) is None

|
||||
# A single dict is wrong — there are two distinct items
|
||||
pytest.fail("Expected a list of output items, got a single dict")
|
||||
|
||||
def test_responses_api_strips_status_from_function_call(self):
|
||||
"""Responses API function_call items have a 'status' field that OpenAI
|
||||
rejects when sent back as input ('Unknown parameter: input[N].status').
|
||||
It must be stripped before the item is stored in conversation history."""
|
||||
resp = _MockResponse(
|
||||
output=[_MockFunctionCall("my_tool", '{"x": 1}', call_id="call_xyz")]
|
||||
)
|
||||
result = _convert_raw_response_to_dict(resp)
|
||||
assert isinstance(result, list)
|
||||
for item in result:
|
||||
assert (
|
||||
"status" not in item
|
||||
), f"'status' must be stripped from Responses API items: {item}"
|
||||
|
||||
def test_responses_api_strips_status_from_message(self):
|
||||
"""Responses API message items also carry 'status'; it must be stripped."""
|
||||
resp = _MockResponse(output=[_MockOutputMessage("Hello")])
|
||||
result = _convert_raw_response_to_dict(resp)
|
||||
assert isinstance(result, list)
|
||||
for item in result:
|
||||
assert (
|
||||
"status" not in item
|
||||
), f"'status' must be stripped from Responses API items: {item}"
|
||||
|
||||
|
||||
# ───────────────────────────────────────────────────────────────────────────
|
||||
# _get_tool_requests (lines 61-86)
|
||||
|
||||
@@ -13,6 +13,7 @@ from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -104,4 +105,10 @@ class UnrealTextToSpeechBlock(Block):
            input_data.text,
            input_data.voice_id,
        )
        self.merge_stats(
            NodeExecutionStats(
                provider_cost=float(len(input_data.text)),
                provider_cost_type="characters",
            )
        )
        yield "mp3_url", api_response["OutputUri"]

@@ -9,6 +9,7 @@ shared tool registry as the SDK path.
import asyncio
import base64
import logging
import math
import os
import re
import shutil
@@ -22,6 +23,7 @@ from typing import TYPE_CHECKING, Any, cast
import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from opentelemetry import trace as otel_trace

from backend.copilot.config import CopilotMode
from backend.copilot.context import get_workspace_manager, set_execution_context
@@ -30,7 +32,6 @@ from backend.copilot.model import (
    ChatSession,
    get_chat_session,
    maybe_append_user_message,
    update_session_title,
    upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
@@ -51,8 +52,8 @@ from backend.copilot.response_model import (
)
from backend.copilot.service import (
    _build_system_prompt,
    _generate_session_title,
    _get_openai_client,
    _update_title_async,
    config,
)
from backend.copilot.token_tracking import persist_and_record_usage
@@ -334,6 +335,7 @@ class _BaselineStreamState:
    text_started: bool = False
    turn_prompt_tokens: int = 0
    turn_completion_tokens: int = 0
    cost_usd: float | None = None
    thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper)
    session_messages: list[ChatMessage] = field(default_factory=list)

@@ -354,6 +356,7 @@ async def _baseline_llm_caller(
    state.thinking_stripper = _ThinkingStripper()

    round_text = ""
    response = None  # initialized before try so finally block can access it
    try:
        client = _get_openai_client()
        typed_messages = cast(list[ChatCompletionMessageParam], messages)
@@ -430,6 +433,20 @@ async def _baseline_llm_caller(
                state.text_started = False
                state.text_block_id = str(uuid.uuid4())
    finally:
        # Extract OpenRouter cost from response headers (in finally so we
        # capture cost even when the stream errors mid-way — we already paid).
        # Accumulate across multi-round tool-calling turns.
        try:
            # Access undocumented _response attribute — same pattern as
            # extract_openrouter_cost() in blocks/llm.py.
            cost_header = response._response.headers.get("x-total-cost")  # type: ignore[attr-defined]
            if cost_header:
                cost = float(cost_header)
                if math.isfinite(cost) and cost >= 0:
                    state.cost_usd = (state.cost_usd or 0.0) + cost
        except (AttributeError, ValueError):
            pass

    # Always persist partial text so the session history stays consistent,
    # even when the stream is interrupted by an exception.
    state.assistant_text += round_text
@@ -686,18 +703,6 @@ def _baseline_conversation_updater(
    )


async def _update_title_async(
    session_id: str, message: str, user_id: str | None
) -> None:
    """Generate and persist a session title in the background."""
    try:
        title = await _generate_session_title(message, user_id, session_id)
        if title and user_id:
            await update_session_title(session_id, user_id, title, only_if_empty=True)
    except Exception as e:
        logger.warning("[Baseline] Failed to update session title: %s", e)


async def _compress_session_messages(
    messages: list[ChatMessage],
    model: str,
@@ -1183,8 +1188,22 @@ async def stream_chat_completion_baseline(
            yield StreamError(errorText=error_msg, code="baseline_error")
            # Still persist whatever we got
    finally:
        # Close Langfuse trace context
        # Set cost attributes on OTEL span before closing
        if _trace_ctx is not None:
            try:
                span = otel_trace.get_current_span()
                if span and span.is_recording():
                    span.set_attribute(
                        "gen_ai.usage.prompt_tokens", state.turn_prompt_tokens
                    )
                    span.set_attribute(
                        "gen_ai.usage.completion_tokens",
                        state.turn_completion_tokens,
                    )
                    if state.cost_usd is not None:
                        span.set_attribute("gen_ai.usage.cost_usd", state.cost_usd)
            except Exception:
                logger.debug("[Baseline] Failed to set OTEL cost attributes")
            try:
                _trace_ctx.__exit__(None, None, None)
            except Exception:
@@ -1226,6 +1245,8 @@ async def stream_chat_completion_baseline(
        prompt_tokens=state.turn_prompt_tokens,
        completion_tokens=state.turn_completion_tokens,
        log_prefix="[Baseline]",
        cost_usd=state.cost_usd,
        model=active_model,
    )

    # Persist structured tool-call history (assistant + tool messages)

@@ -4,7 +4,7 @@ These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState``
without requiring API keys, database connections, or network access.
"""

from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from openai.types.chat import ChatCompletionToolParam
@@ -631,3 +631,200 @@ class TestPrepareBaselineAttachments:

        assert hint == ""
        assert blocks == []


class TestBaselineCostExtraction:
    """Tests for x-total-cost header extraction in _baseline_llm_caller."""

    @pytest.mark.asyncio
    async def test_cost_usd_extracted_from_response_header(self):
        """state.cost_usd is set from x-total-cost header when present."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        # Build a mock raw httpx response with the cost header
        mock_raw_response = MagicMock()
        mock_raw_response.headers = {"x-total-cost": "0.0123"}

        # Build a mock async streaming response that yields no chunks but has
        # a _response attribute pointing to the mock httpx response
        mock_stream_response = MagicMock()
        mock_stream_response._response = mock_raw_response

        async def empty_aiter():
            return
            yield  # make it an async generator

        mock_stream_response.__aiter__ = lambda self: empty_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(
            return_value=mock_stream_response
        )

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd == pytest.approx(0.0123)

    @pytest.mark.asyncio
    async def test_cost_usd_accumulates_across_calls(self):
        """cost_usd accumulates when _baseline_llm_caller is called multiple times."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        def make_stream_mock(cost: str) -> MagicMock:
            mock_raw = MagicMock()
            mock_raw.headers = {"x-total-cost": cost}
            mock_stream = MagicMock()
            mock_stream._response = mock_raw

            async def empty_aiter():
                return
                yield

            mock_stream.__aiter__ = lambda self: empty_aiter()
            return mock_stream

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(
            side_effect=[make_stream_mock("0.01"), make_stream_mock("0.02")]
        )

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "first"}],
                tools=[],
                state=state,
            )
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "second"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd == pytest.approx(0.03)

    @pytest.mark.asyncio
    async def test_no_cost_when_header_absent(self):
        """state.cost_usd remains None when response has no x-total-cost header."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        mock_raw = MagicMock()
        mock_raw.headers = {}
        mock_stream = MagicMock()
        mock_stream._response = mock_raw

        async def empty_aiter():
            return
            yield

        mock_stream.__aiter__ = lambda self: empty_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd is None

    @pytest.mark.asyncio
    async def test_cost_extracted_even_when_stream_raises(self):
        """cost_usd is captured in the finally block even when streaming fails."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        mock_raw = MagicMock()
        mock_raw.headers = {"x-total-cost": "0.005"}
        mock_stream = MagicMock()
        mock_stream._response = mock_raw

        async def failing_aiter():
            raise RuntimeError("stream error")
            yield  # make it an async generator

        mock_stream.__aiter__ = lambda self: failing_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)

        with (
            patch(
                "backend.copilot.baseline.service._get_openai_client",
                return_value=mock_client,
            ),
            pytest.raises(RuntimeError, match="stream error"),
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd == pytest.approx(0.005)

    @pytest.mark.asyncio
    async def test_no_cost_when_api_call_raises_before_stream(self):
        """finally block is safe when response is None (API call failed before yielding)."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="gpt-4o-mini")

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(
            side_effect=RuntimeError("connection refused")
        )

        with (
            patch(
                "backend.copilot.baseline.service._get_openai_client",
                return_value=mock_client,
            ),
            pytest.raises(RuntimeError, match="connection refused"),
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        # response was never assigned so cost extraction must not raise
        assert state.cost_usd is None

@@ -14,6 +14,7 @@ from prisma.types import (
    ChatSessionUpdateInput,
    ChatSessionWhereInput,
)
from pydantic import BaseModel

from backend.data import db
from backend.util.json import SafeJson, sanitize_string
@@ -30,6 +31,15 @@ from .model import get_chat_session as get_chat_session_cached
logger = logging.getLogger(__name__)


class PaginatedMessages(BaseModel):
    """Result of a paginated message query."""

    messages: list[ChatMessage]
    has_more: bool
    oldest_sequence: int | None
    session: ChatSessionInfo


async def get_chat_session(session_id: str) -> ChatSession | None:
    """Get a chat session by ID from the database."""
    session = await PrismaChatSession.prisma().find_unique(
@@ -39,6 +49,116 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
    return ChatSession.from_db(session) if session else None


async def get_chat_session_metadata(session_id: str) -> ChatSessionInfo | None:
    """Get chat session metadata (without messages) for ownership validation."""
    session = await PrismaChatSession.prisma().find_unique(
        where={"id": session_id},
    )
    return ChatSessionInfo.from_db(session) if session else None


async def get_chat_messages_paginated(
    session_id: str,
    limit: int = 50,
    before_sequence: int | None = None,
    user_id: str | None = None,
) -> PaginatedMessages | None:
    """Get paginated messages for a session, newest first.

    Verifies session existence (and ownership when ``user_id`` is provided)
    in the same query as the message fetch. Returns ``None`` when the session
    is not found or does not belong to the user.

    Args:
        session_id: The chat session ID.
        limit: Max messages to return.
        before_sequence: Cursor — return messages with sequence < this value.
        user_id: If provided, filters via ``Session.userId`` so only the
            session owner's messages are returned (acts as an ownership guard).
    """
    # Build session-existence / ownership check
    session_where: ChatSessionWhereInput = {"id": session_id}
    if user_id is not None:
        session_where["userId"] = user_id

    # Build message include — fetch paginated messages in the same query
    msg_include: dict[str, Any] = {
        "order_by": {"sequence": "desc"},
        "take": limit + 1,
    }
    if before_sequence is not None:
        msg_include["where"] = {"sequence": {"lt": before_sequence}}

    # Single query: session existence/ownership + paginated messages
    session = await PrismaChatSession.prisma().find_first(
        where=session_where,
        include={"Messages": msg_include},
    )

    if session is None:
        return None

    session_info = ChatSessionInfo.from_db(session)
    results = list(session.Messages) if session.Messages else []

    has_more = len(results) > limit
    results = results[:limit]

    # Reverse to ascending order
    results.reverse()

    # Tool-call boundary fix: if the oldest message is a tool message,
    # expand backward to include the preceding assistant message that
    # owns the tool_calls, so convertChatSessionMessagesToUiMessages
    # can pair them correctly.
    _BOUNDARY_SCAN_LIMIT = 10
    if results and results[0].role == "tool":
        boundary_where: dict[str, Any] = {
            "sessionId": session_id,
            "sequence": {"lt": results[0].sequence},
        }
        if user_id is not None:
            boundary_where["Session"] = {"is": {"userId": user_id}}
        extra = await PrismaChatMessage.prisma().find_many(
            where=boundary_where,
            order={"sequence": "desc"},
            take=_BOUNDARY_SCAN_LIMIT,
        )
        # Find the first non-tool message (should be the assistant)
        boundary_msgs = []
        found_owner = False
        for msg in extra:
            boundary_msgs.append(msg)
            if msg.role != "tool":
                found_owner = True
                break
        boundary_msgs.reverse()
        if not found_owner:
            logger.warning(
                "Boundary expansion did not find owning assistant message "
                "for session=%s before sequence=%s (%d msgs scanned)",
                session_id,
                results[0].sequence,
                len(extra),
            )
        if boundary_msgs:
            results = boundary_msgs + results
            # Only mark has_more if the expanded boundary isn't the
            # very start of the conversation (sequence 0).
            if boundary_msgs[0].sequence > 0:
                has_more = True

    messages = [ChatMessage.from_db(m) for m in results]
    oldest_sequence = messages[0].sequence if messages else None

    return PaginatedMessages(
        messages=messages,
        has_more=has_more,
        oldest_sequence=oldest_sequence,
        session=session_info,
    )


async def create_chat_session(
    session_id: str,
    user_id: str,

@@ -1,7 +1,341 @@
import pytest
"""Unit tests for copilot.db — paginated message queries."""

from .db import set_turn_duration
from .model import ChatMessage, ChatSession, get_chat_session, upsert_chat_session
from __future__ import annotations

from datetime import UTC, datetime
from typing import Any
from unittest.mock import AsyncMock, patch

import pytest
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession

from backend.copilot.db import (
    PaginatedMessages,
    get_chat_messages_paginated,
    set_turn_duration,
)
from backend.copilot.model import ChatMessage as CopilotChatMessage
from backend.copilot.model import ChatSession, get_chat_session, upsert_chat_session


def _make_msg(
    sequence: int,
    role: str = "assistant",
    content: str | None = "hello",
    tool_calls: Any = None,
) -> PrismaChatMessage:
    """Build a minimal PrismaChatMessage for testing."""
    return PrismaChatMessage(
        id=f"msg-{sequence}",
        createdAt=datetime.now(UTC),
        sessionId="sess-1",
        role=role,
        content=content,
        sequence=sequence,
        toolCalls=tool_calls,
        name=None,
        toolCallId=None,
        refusal=None,
        functionCall=None,
    )


def _make_session(
    session_id: str = "sess-1",
    user_id: str = "user-1",
    messages: list[PrismaChatMessage] | None = None,
) -> PrismaChatSession:
    """Build a minimal PrismaChatSession for testing."""
    now = datetime.now(UTC)
    session = PrismaChatSession.model_construct(
        id=session_id,
        createdAt=now,
        updatedAt=now,
        userId=user_id,
        credentials={},
        successfulAgentRuns={},
        successfulAgentSchedules={},
        totalPromptTokens=0,
        totalCompletionTokens=0,
        title=None,
        metadata={},
        Messages=messages or [],
    )
    return session


SESSION_ID = "sess-1"


@pytest.fixture()
def mock_db():
    """Patch ChatSession.prisma().find_first and ChatMessage.prisma().find_many.

    find_first is used for the main query (session + included messages).
    find_many is used only for boundary expansion queries.
    """
    with (
        patch.object(PrismaChatSession, "prisma") as mock_session_prisma,
        patch.object(PrismaChatMessage, "prisma") as mock_msg_prisma,
    ):
        find_first = AsyncMock()
        mock_session_prisma.return_value.find_first = find_first

        find_many = AsyncMock(return_value=[])
        mock_msg_prisma.return_value.find_many = find_many

        yield find_first, find_many


# ---------- Basic pagination ----------


@pytest.mark.asyncio
async def test_basic_page_returns_messages_ascending(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """Messages are returned in ascending sequence order."""
    find_first, _ = mock_db
    find_first.return_value = _make_session(
        messages=[_make_msg(3), _make_msg(2), _make_msg(1)],
    )

    page = await get_chat_messages_paginated(SESSION_ID, limit=5)

    assert isinstance(page, PaginatedMessages)
    assert [m.sequence for m in page.messages] == [1, 2, 3]
    assert page.has_more is False
    assert page.oldest_sequence == 1


@pytest.mark.asyncio
async def test_has_more_when_results_exceed_limit(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """has_more is True when DB returns more than limit items."""
    find_first, _ = mock_db
    find_first.return_value = _make_session(
        messages=[_make_msg(3), _make_msg(2), _make_msg(1)],
    )

    page = await get_chat_messages_paginated(SESSION_ID, limit=2)

    assert page is not None
    assert page.has_more is True
    assert len(page.messages) == 2
    assert [m.sequence for m in page.messages] == [2, 3]


@pytest.mark.asyncio
async def test_empty_session_returns_no_messages(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    find_first, _ = mock_db
    find_first.return_value = _make_session(messages=[])

    page = await get_chat_messages_paginated(SESSION_ID, limit=50)

    assert page is not None
    assert page.messages == []
    assert page.has_more is False
    assert page.oldest_sequence is None


@pytest.mark.asyncio
async def test_before_sequence_filters_correctly(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """before_sequence is passed as a where filter inside the Messages include."""
    find_first, _ = mock_db
    find_first.return_value = _make_session(
        messages=[_make_msg(2), _make_msg(1)],
    )

    await get_chat_messages_paginated(SESSION_ID, limit=50, before_sequence=5)

    call_kwargs = find_first.call_args
    include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
    assert include["Messages"]["where"] == {"sequence": {"lt": 5}}


@pytest.mark.asyncio
async def test_no_where_on_messages_without_before_sequence(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """Without before_sequence, the Messages include has no where clause."""
    find_first, _ = mock_db
    find_first.return_value = _make_session(messages=[_make_msg(1)])

    await get_chat_messages_paginated(SESSION_ID, limit=50)

    call_kwargs = find_first.call_args
    include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
    assert "where" not in include["Messages"]


@pytest.mark.asyncio
async def test_user_id_filter_applied_to_session_where(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """user_id adds a userId filter to the session-level where clause."""
    find_first, _ = mock_db
    find_first.return_value = _make_session(messages=[_make_msg(1)])

    await get_chat_messages_paginated(SESSION_ID, limit=50, user_id="user-abc")

    call_kwargs = find_first.call_args
    where = call_kwargs.kwargs.get("where") or call_kwargs[1].get("where")
    assert where["userId"] == "user-abc"


@pytest.mark.asyncio
async def test_session_not_found_returns_none(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """Returns None when session doesn't exist or user doesn't own it."""
    find_first, _ = mock_db
    find_first.return_value = None

    page = await get_chat_messages_paginated(SESSION_ID, limit=50)

    assert page is None


@pytest.mark.asyncio
async def test_session_info_included_in_result(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """PaginatedMessages includes session metadata."""
    find_first, _ = mock_db
    find_first.return_value = _make_session(messages=[_make_msg(1)])

    page = await get_chat_messages_paginated(SESSION_ID, limit=50)

    assert page is not None
    assert page.session.session_id == SESSION_ID


# ---------- Backward boundary expansion ----------


@pytest.mark.asyncio
async def test_boundary_expansion_includes_assistant(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """When page starts with a tool message, expand backward to include
    the owning assistant message."""
    find_first, find_many = mock_db
    find_first.return_value = _make_session(
        messages=[_make_msg(5, role="tool"), _make_msg(4, role="tool")],
    )
    find_many.return_value = [_make_msg(3, role="assistant")]

    page = await get_chat_messages_paginated(SESSION_ID, limit=5)

    assert page is not None
    assert [m.sequence for m in page.messages] == [3, 4, 5]
    assert page.messages[0].role == "assistant"
    assert page.oldest_sequence == 3


@pytest.mark.asyncio
async def test_boundary_expansion_includes_multiple_tool_msgs(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """Boundary expansion scans past consecutive tool messages to find
    the owning assistant."""
    find_first, find_many = mock_db
    find_first.return_value = _make_session(
        messages=[_make_msg(7, role="tool")],
    )
    find_many.return_value = [
        _make_msg(6, role="tool"),
        _make_msg(5, role="tool"),
        _make_msg(4, role="assistant"),
    ]

    page = await get_chat_messages_paginated(SESSION_ID, limit=5)

    assert page is not None
    assert [m.sequence for m in page.messages] == [4, 5, 6, 7]
    assert page.messages[0].role == "assistant"


@pytest.mark.asyncio
async def test_boundary_expansion_sets_has_more_when_not_at_start(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """After boundary expansion, has_more=True if expanded msgs aren't at seq 0."""
    find_first, find_many = mock_db
    find_first.return_value = _make_session(
        messages=[_make_msg(3, role="tool")],
    )
    find_many.return_value = [_make_msg(2, role="assistant")]

    page = await get_chat_messages_paginated(SESSION_ID, limit=5)

    assert page is not None
    assert page.has_more is True


@pytest.mark.asyncio
async def test_boundary_expansion_no_has_more_at_conversation_start(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """has_more stays False when boundary expansion reaches seq 0."""
    find_first, find_many = mock_db
    find_first.return_value = _make_session(
        messages=[_make_msg(1, role="tool")],
    )
    find_many.return_value = [_make_msg(0, role="assistant")]

    page = await get_chat_messages_paginated(SESSION_ID, limit=5)

    assert page is not None
    assert page.has_more is False
    assert page.oldest_sequence == 0


@pytest.mark.asyncio
async def test_no_boundary_expansion_when_first_msg_not_tool(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """No boundary expansion when the first message is not a tool message."""
    find_first, find_many = mock_db
    find_first.return_value = _make_session(
        messages=[_make_msg(3, role="user"), _make_msg(2, role="assistant")],
    )

    page = await get_chat_messages_paginated(SESSION_ID, limit=5)

    assert page is not None
    assert find_many.call_count == 0
    assert [m.sequence for m in page.messages] == [2, 3]


@pytest.mark.asyncio
async def test_boundary_expansion_warns_when_no_owner_found(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """When boundary scan doesn't find a non-tool message, a warning is logged
    and the boundary messages are still included."""
    find_first, find_many = mock_db
    find_first.return_value = _make_session(
        messages=[_make_msg(10, role="tool")],
    )
    find_many.return_value = [_make_msg(i, role="tool") for i in range(9, -1, -1)]

    with patch("backend.copilot.db.logger") as mock_logger:
        page = await get_chat_messages_paginated(SESSION_ID, limit=5)
        mock_logger.warning.assert_called_once()

    assert page is not None
    assert page.messages[0].role == "tool"
    assert len(page.messages) > 1


# ---------- Turn duration (integration tests) ----------


@pytest.mark.asyncio(loop_scope="session")
@@ -15,8 +349,8 @@ async def test_set_turn_duration_updates_cache_in_place(setup_test_user, test_us
    """
    session = ChatSession.new(user_id=test_user_id, dry_run=False)
    session.messages = [
        ChatMessage(role="user", content="hello"),
        ChatMessage(role="assistant", content="hi there"),
        CopilotChatMessage(role="user", content="hello"),
        CopilotChatMessage(role="assistant", content="hi there"),
    ]
    session = await upsert_chat_session(session)

@@ -41,7 +375,7 @@ async def test_set_turn_duration_no_assistant_message(setup_test_user, test_user
    """set_turn_duration is a no-op when there are no assistant messages."""
    session = ChatSession.new(user_id=test_user_id, dry_run=False)
    session.messages = [
        ChatMessage(role="user", content="hello"),
        CopilotChatMessage(role="user", content="hello"),
    ]
    session = await upsert_chat_session(session)

@@ -151,8 +151,8 @@ class CoPilotProcessor:
        This method is called once per worker thread to set up the async event
        loop and initialize any required resources.

        Database is accessed only through DatabaseManager, so we don't need to connect
        to Prisma directly.
        DB operations route through DatabaseManagerAsyncClient (RPC) via the
        db_accessors pattern — no direct Prisma connection is needed here.
        """
        configure_logging()
        set_service_name("CoPilotExecutor")

@@ -64,6 +64,7 @@ class ChatMessage(BaseModel):
    refusal: str | None = None
    tool_calls: list[dict] | None = None
    function_call: dict | None = None
    sequence: int | None = None
    duration_ms: int | None = None

    @staticmethod
@@ -77,6 +78,7 @@ class ChatMessage(BaseModel):
        refusal=prisma_message.refusal,
        tool_calls=_parse_json_field(prisma_message.toolCalls),
        function_call=_parse_json_field(prisma_message.functionCall),
        sequence=prisma_message.sequence,
        duration_ms=prisma_message.durationMs,
    )


@@ -15,6 +15,7 @@ from prisma.models import User as PrismaUser
 from pydantic import BaseModel, Field
 from redis.exceptions import RedisError

+from backend.data.db_accessors import user_db
 from backend.data.redis_client import get_redis_async
 from backend.util.cache import cached

@@ -409,9 +410,12 @@ async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
     prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
     cached and then persists after the user is created with a higher tier.
     """
-    user = await PrismaUser.prisma().find_unique(where={"id": user_id})
-    if user and user.subscriptionTier:  # type: ignore[reportAttributeAccessIssue]
-        return SubscriptionTier(user.subscriptionTier)  # type: ignore[reportAttributeAccessIssue]
+    try:
+        user = await user_db().get_user_by_id(user_id)
+    except Exception:
+        raise _UserNotFoundError(user_id)
+    if user.subscription_tier:
+        return SubscriptionTier(user.subscription_tier)
     raise _UserNotFoundError(user_id)

@@ -401,66 +401,49 @@ class TestGetUserTier:
         """Clear the get_user_tier cache before each test."""
         get_user_tier.cache_clear()  # type: ignore[attr-defined]

+    def _mock_user_db(
+        self, subscription_tier: str | None = None, raises: Exception | None = None
+    ):
+        """Return a patched user_db() whose get_user_by_id behaves as specified."""
+        mock_db = AsyncMock()
+        if raises is not None:
+            mock_db.get_user_by_id = AsyncMock(side_effect=raises)
+        else:
+            mock_user = MagicMock()
+            mock_user.subscription_tier = subscription_tier
+            mock_db.get_user_by_id = AsyncMock(return_value=mock_user)
+        return mock_db
+
     @pytest.mark.asyncio
     async def test_returns_tier_from_db(self):
         """Should return the tier stored in the user record."""
-        mock_user = MagicMock()
-        mock_user.subscriptionTier = "PRO"
-
-        mock_prisma = AsyncMock()
-        mock_prisma.find_unique = AsyncMock(return_value=mock_user)
-
-        with patch(
-            "backend.copilot.rate_limit.PrismaUser.prisma",
-            return_value=mock_prisma,
-        ):
+        mock_db = self._mock_user_db(subscription_tier="PRO")
+        with patch("backend.copilot.rate_limit.user_db", return_value=mock_db):
             tier = await get_user_tier(_USER)

         assert tier == SubscriptionTier.PRO

     @pytest.mark.asyncio
     async def test_returns_default_when_user_not_found(self):
         """Should return DEFAULT_TIER when user is not in the DB."""
-        mock_prisma = AsyncMock()
-        mock_prisma.find_unique = AsyncMock(return_value=None)
-
-        with patch(
-            "backend.copilot.rate_limit.PrismaUser.prisma",
-            return_value=mock_prisma,
-        ):
+        mock_db = self._mock_user_db(raises=Exception("not found"))
+        with patch("backend.copilot.rate_limit.user_db", return_value=mock_db):
             tier = await get_user_tier(_USER)

         assert tier == DEFAULT_TIER

     @pytest.mark.asyncio
     async def test_returns_default_when_tier_is_none(self):
-        """Should return DEFAULT_TIER when subscriptionTier is None."""
-        mock_user = MagicMock()
-        mock_user.subscriptionTier = None
-
-        mock_prisma = AsyncMock()
-        mock_prisma.find_unique = AsyncMock(return_value=mock_user)
-
-        with patch(
-            "backend.copilot.rate_limit.PrismaUser.prisma",
-            return_value=mock_prisma,
-        ):
+        """Should return DEFAULT_TIER when subscription_tier is None."""
+        mock_db = self._mock_user_db(subscription_tier=None)
+        with patch("backend.copilot.rate_limit.user_db", return_value=mock_db):
             tier = await get_user_tier(_USER)

         assert tier == DEFAULT_TIER

     @pytest.mark.asyncio
     async def test_returns_default_on_db_error(self):
         """Should fall back to DEFAULT_TIER when DB raises."""
-        mock_prisma = AsyncMock()
-        mock_prisma.find_unique = AsyncMock(side_effect=Exception("DB down"))
-
-        with patch(
-            "backend.copilot.rate_limit.PrismaUser.prisma",
-            return_value=mock_prisma,
-        ):
+        mock_db = self._mock_user_db(raises=Exception("DB down"))
+        with patch("backend.copilot.rate_limit.user_db", return_value=mock_db):
             tier = await get_user_tier(_USER)

         assert tier == DEFAULT_TIER

     @pytest.mark.asyncio
@@ -470,26 +453,14 @@ class TestGetUserTier:
         Regression test: a transient DB failure previously cached DEFAULT_TIER
         for 5 minutes, incorrectly downgrading higher-tier users until expiry.
         """
-        failing_prisma = AsyncMock()
-        failing_prisma.find_unique = AsyncMock(side_effect=Exception("DB down"))
-
-        with patch(
-            "backend.copilot.rate_limit.PrismaUser.prisma",
-            return_value=failing_prisma,
-        ):
+        failing_db = self._mock_user_db(raises=Exception("DB down"))
+        with patch("backend.copilot.rate_limit.user_db", return_value=failing_db):
             tier1 = await get_user_tier(_USER)
             assert tier1 == DEFAULT_TIER

         # Now DB recovers and returns PRO
-        mock_user = MagicMock()
-        mock_user.subscriptionTier = "PRO"
-        ok_prisma = AsyncMock()
-        ok_prisma.find_unique = AsyncMock(return_value=mock_user)
-
-        with patch(
-            "backend.copilot.rate_limit.PrismaUser.prisma",
-            return_value=ok_prisma,
-        ):
+        ok_db = self._mock_user_db(subscription_tier="PRO")
+        with patch("backend.copilot.rate_limit.user_db", return_value=ok_db):
             tier2 = await get_user_tier(_USER)

         # Should get PRO now — the error result was not cached
@@ -498,18 +469,9 @@ class TestGetUserTier:
     @pytest.mark.asyncio
     async def test_returns_default_on_invalid_tier_value(self):
         """Should fall back to DEFAULT_TIER when stored value is invalid."""
-        mock_user = MagicMock()
-        mock_user.subscriptionTier = "invalid-tier"
-
-        mock_prisma = AsyncMock()
-        mock_prisma.find_unique = AsyncMock(return_value=mock_user)
-
-        with patch(
-            "backend.copilot.rate_limit.PrismaUser.prisma",
-            return_value=mock_prisma,
-        ):
+        mock_db = self._mock_user_db(subscription_tier="invalid-tier")
+        with patch("backend.copilot.rate_limit.user_db", return_value=mock_db):
             tier = await get_user_tier(_USER)

         assert tier == DEFAULT_TIER

     @pytest.mark.asyncio
@@ -522,26 +484,14 @@
         stale cached FREE tier for up to 5 minutes.
         """
         # First call: user does not exist yet
-        missing_prisma = AsyncMock()
-        missing_prisma.find_unique = AsyncMock(return_value=None)
-
-        with patch(
-            "backend.copilot.rate_limit.PrismaUser.prisma",
-            return_value=missing_prisma,
-        ):
+        missing_db = self._mock_user_db(raises=Exception("not found"))
+        with patch("backend.copilot.rate_limit.user_db", return_value=missing_db):
             tier1 = await get_user_tier(_USER)
             assert tier1 == DEFAULT_TIER

         # Second call: user now exists with PRO tier
-        mock_user = MagicMock()
-        mock_user.subscriptionTier = "PRO"
-        ok_prisma = AsyncMock()
-        ok_prisma.find_unique = AsyncMock(return_value=mock_user)
-
-        with patch(
-            "backend.copilot.rate_limit.PrismaUser.prisma",
-            return_value=ok_prisma,
-        ):
+        ok_db = self._mock_user_db(subscription_tier="PRO")
+        with patch("backend.copilot.rate_limit.user_db", return_value=ok_db):
             tier2 = await get_user_tier(_USER)

         # Should get PRO — the not-found result was not cached
@@ -598,20 +548,19 @@ class TestSetUserTier:
     @pytest.mark.asyncio
     async def test_cache_invalidated_after_set(self):
         """After set_user_tier, get_user_tier should query DB again (not cache)."""
-        # First, populate the cache with BUSINESS
+        # First, populate the cache with BUSINESS via user_db() mock
+        mock_db_biz = AsyncMock()
         mock_user_biz = MagicMock()
-        mock_user_biz.subscriptionTier = "BUSINESS"
-        mock_prisma_get = AsyncMock()
-        mock_prisma_get.find_unique = AsyncMock(return_value=mock_user_biz)
+        mock_user_biz.subscription_tier = "BUSINESS"
+        mock_db_biz.get_user_by_id = AsyncMock(return_value=mock_user_biz)

-        with patch(
-            "backend.copilot.rate_limit.PrismaUser.prisma",
-            return_value=mock_prisma_get,
-        ):
+        with patch("backend.copilot.rate_limit.user_db", return_value=mock_db_biz):
             tier_before = await get_user_tier(_USER)
             assert tier_before == SubscriptionTier.BUSINESS

-        # Now set tier to ENTERPRISE (this should invalidate the cache)
+        # Now set tier to ENTERPRISE via PrismaUser.prisma (set_user_tier still
+        # uses Prisma directly since it's only called from admin API where Prisma
+        # is connected).
         mock_prisma_set = AsyncMock()
         mock_prisma_set.update = AsyncMock(return_value=None)

@@ -622,15 +571,12 @@
             await set_user_tier(_USER, SubscriptionTier.ENTERPRISE)

         # Now get_user_tier should hit DB again (cache was invalidated)
+        mock_db_ent = AsyncMock()
         mock_user_ent = MagicMock()
-        mock_user_ent.subscriptionTier = "ENTERPRISE"
-        mock_prisma_get2 = AsyncMock()
-        mock_prisma_get2.find_unique = AsyncMock(return_value=mock_user_ent)
+        mock_user_ent.subscription_tier = "ENTERPRISE"
+        mock_db_ent.get_user_by_id = AsyncMock(return_value=mock_user_ent)

-        with patch(
-            "backend.copilot.rate_limit.PrismaUser.prisma",
-            return_value=mock_prisma_get2,
-        ):
+        with patch("backend.copilot.rate_limit.user_db", return_value=mock_db_ent):
             tier_after = await get_user_tier(_USER)

         assert tier_after == SubscriptionTier.ENTERPRISE

@@ -29,6 +29,7 @@ from claude_agent_sdk import (
 )
 from langfuse import propagate_attributes
+from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
 from opentelemetry import trace as otel_trace
 from pydantic import BaseModel

 from backend.copilot.context import get_workspace_manager
@@ -64,7 +65,6 @@ from ..model import (
     ChatSession,
     get_chat_session,
     maybe_append_user_message,
-    update_session_title,
     upsert_chat_session,
 )
 from ..prompting import get_sdk_supplement
@@ -83,11 +83,7 @@ from ..response_model import (
     StreamToolOutputAvailable,
     StreamUsage,
 )
-from ..service import (
-    _build_system_prompt,
-    _generate_session_title,
-    _is_langfuse_configured,
-)
+from ..service import _build_system_prompt, _is_langfuse_configured, _update_title_async
 from ..token_tracking import persist_and_record_usage
 from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
 from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
@@ -2372,8 +2368,26 @@ async def stream_chat_completion_sdk(

             raise
     finally:
-        # --- Close OTEL context ---
+        # --- Close OTEL context (with cost attributes) ---
         if _otel_ctx is not None:
+            try:
+                span = otel_trace.get_current_span()
+                if span and span.is_recording():
+                    span.set_attribute("gen_ai.usage.prompt_tokens", turn_prompt_tokens)
+                    span.set_attribute(
+                        "gen_ai.usage.completion_tokens", turn_completion_tokens
+                    )
+                    span.set_attribute(
+                        "gen_ai.usage.cache_read_tokens", turn_cache_read_tokens
+                    )
+                    span.set_attribute(
+                        "gen_ai.usage.cache_creation_tokens",
+                        turn_cache_creation_tokens,
+                    )
+                    if turn_cost_usd is not None:
+                        span.set_attribute("gen_ai.usage.cost_usd", turn_cost_usd)
+            except Exception:
+                logger.debug("Failed to set OTEL cost attributes", exc_info=True)
             try:
                 _otel_ctx.__exit__(*sys.exc_info())
             except Exception:
@@ -2391,6 +2405,8 @@ async def stream_chat_completion_sdk(
             cache_creation_tokens=turn_cache_creation_tokens,
             log_prefix=log_prefix,
             cost_usd=turn_cost_usd,
+            model=config.model,
+            provider="anthropic",
         )

         # --- Persist session messages ---
@@ -2495,18 +2511,3 @@ async def stream_chat_completion_sdk(
     finally:
         # Release stream lock to allow new streams for this session
         await lock.release()
-
-
-async def _update_title_async(
-    session_id: str, message: str, user_id: str | None = None
-) -> None:
-    """Background task to update session title."""
-    try:
-        title = await _generate_session_title(
-            message, user_id=user_id, session_id=session_id
-        )
-        if title and user_id:
-            await update_session_title(session_id, user_id, title, only_if_empty=True)
-            logger.debug("[SDK] Generated title for %s: %s", session_id, title)
-    except Exception as e:
-        logger.warning("[SDK] Failed to update session title: %s", e)

@@ -22,7 +22,12 @@ from backend.util.exceptions import NotAuthorizedError, NotFoundError
 from backend.util.settings import AppEnvironment, Settings

 from .config import ChatConfig
-from .model import ChatSessionInfo, get_chat_session, upsert_chat_session
+from .model import (
+    ChatSessionInfo,
+    get_chat_session,
+    update_session_title,
+    upsert_chat_session,
+)

 logger = logging.getLogger(__name__)

@@ -202,6 +207,22 @@ async def _generate_session_title(
     return None


+async def _update_title_async(
+    session_id: str, message: str, user_id: str | None = None
+) -> None:
+    """Generate and persist a session title in the background.
+
+    Shared by both the SDK and baseline execution paths.
+    """
+    try:
+        title = await _generate_session_title(message, user_id, session_id)
+        if title and user_id:
+            await update_session_title(session_id, user_id, title, only_if_empty=True)
+            logger.debug("Generated title for session %s", session_id)
+    except Exception as e:
+        logger.warning("Failed to update session title for %s: %s", session_id, e)
+
+
 async def assign_user_to_session(
     session_id: str,
     user_id: str,

@@ -4,17 +4,85 @@ Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
 1. Append a ``Usage`` record to the session.
 2. Log the turn's token counts.
 3. Record weighted usage in Redis for rate-limiting.
+4. Write a PlatformCostLog entry for admin cost tracking.

 This module extracts that common logic so both paths stay in sync.
 """

+import asyncio
 import logging
+import math
+import re
+import threading

+from backend.data.db_accessors import platform_cost_db
+from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
+
 from .model import ChatSession, Usage
 from .rate_limit import record_token_usage

 logger = logging.getLogger(__name__)

+# Hold strong references to in-flight cost log tasks to prevent GC.
+_pending_log_tasks: set[asyncio.Task[None]] = set()
+# Guards all reads and writes to _pending_log_tasks. Done callbacks (discard)
+# fire from the event loop thread; drain_pending_cost_logs iterates the set
+# from any caller — the lock prevents RuntimeError from concurrent modification.
+_pending_log_tasks_lock = threading.Lock()
+# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
+# shared across event loops running in different threads.
+_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}
+
+
+def _get_log_semaphore() -> asyncio.Semaphore:
+    loop = asyncio.get_running_loop()
+    sem = _log_semaphores.get(loop)
+    if sem is None:
+        sem = asyncio.Semaphore(50)
+        _log_semaphores[loop] = sem
+    return sem
+
+
+def _schedule_cost_log(entry: PlatformCostEntry) -> None:
+    """Schedule a fire-and-forget cost log via DatabaseManagerAsyncClient RPC."""
+
+    async def _safe_log() -> None:
+        async with _get_log_semaphore():
+            try:
+                await platform_cost_db().log_platform_cost(entry)
+            except Exception:
+                logger.exception(
+                    "Failed to log platform cost for user=%s provider=%s block=%s",
+                    entry.user_id,
+                    entry.provider,
+                    entry.block_name,
+                )
+
+    task = asyncio.create_task(_safe_log())
+    with _pending_log_tasks_lock:
+        _pending_log_tasks.add(task)
+
+    def _remove(t: asyncio.Task[None]) -> None:
+        with _pending_log_tasks_lock:
+            _pending_log_tasks.discard(t)
+
+    task.add_done_callback(_remove)
+
+
+# Identifiers used by PlatformCostLog for copilot turns (not tied to a real
+# block/credential in the block_cost_config or credentials_store tables).
+COPILOT_BLOCK_ID = "copilot"
+COPILOT_CREDENTIAL_ID = "copilot_system"
+
+
+def _copilot_block_name(log_prefix: str) -> str:
+    """Extract stable block_name from ``"[SDK][session][T1]"`` -> ``"copilot:SDK"``."""
+    match = re.search(r"\[([A-Za-z][A-Za-z0-9_]*)\]", log_prefix)
+    if match:
+        return f"{COPILOT_BLOCK_ID}:{match.group(1)}"
+    tag = log_prefix.strip(" []")
+    return f"{COPILOT_BLOCK_ID}:{tag}" if tag else COPILOT_BLOCK_ID
+
+
 async def persist_and_record_usage(
     *,
@@ -26,6 +94,8 @@ async def persist_and_record_usage(
     cache_creation_tokens: int = 0,
     log_prefix: str = "",
     cost_usd: float | str | None = None,
+    model: str | None = None,
+    provider: str = "open_router",
 ) -> int:
     """Persist token usage to session and record for rate limiting.

@@ -38,6 +108,7 @@ async def persist_and_record_usage(
         cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
         log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
         cost_usd: Optional cost for logging (float from SDK, str otherwise).
+        provider: Cost provider name (e.g. "anthropic", "open_router").

     Returns:
         The computed total_tokens (prompt + completion; cache excluded).
@@ -47,12 +118,13 @@ async def persist_and_record_usage(
     cache_read_tokens = max(0, cache_read_tokens)
     cache_creation_tokens = max(0, cache_creation_tokens)

-    if (
+    no_tokens = (
         prompt_tokens <= 0
         and completion_tokens <= 0
         and cache_read_tokens <= 0
         and cache_creation_tokens <= 0
-    ):
+    )
+    if no_tokens and cost_usd is None:
         return 0

     # total_tokens = prompt + completion. Cache tokens are tracked
@@ -73,14 +145,14 @@

     if cache_read_tokens or cache_creation_tokens:
         logger.info(
-            f"{log_prefix} Turn usage: uncached={prompt_tokens}, "
-            f"cache_read={cache_read_tokens}, cache_create={cache_creation_tokens}, "
-            f"output={completion_tokens}, total={total_tokens}, cost_usd={cost_usd}"
+            f"{log_prefix} Turn usage: uncached={prompt_tokens}, cache_read={cache_read_tokens},"
+            f" cache_create={cache_creation_tokens}, output={completion_tokens},"
+            f" total={total_tokens}, cost_usd={cost_usd}"
         )
     else:
         logger.info(
-            f"{log_prefix} Turn usage: prompt={prompt_tokens}, "
-            f"completion={completion_tokens}, total={total_tokens}"
+            f"{log_prefix} Turn usage: prompt={prompt_tokens}, completion={completion_tokens},"
+            f" total={total_tokens}"
         )

     if user_id:
@@ -93,6 +165,54 @@ async def persist_and_record_usage(
                 cache_creation_tokens=cache_creation_tokens,
             )
         except Exception as usage_err:
-            logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")
+            logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)
+
+    # Log to PlatformCostLog for admin cost dashboard.
+    # Include entries where cost_usd is set even if token count is 0
+    # (e.g. fully-cached Anthropic responses where only cache tokens
+    # accumulate a charge without incrementing total_tokens).
+    if user_id and (total_tokens > 0 or cost_usd is not None):
+        cost_float = None
+        if cost_usd is not None:
+            try:
+                val = float(cost_usd)
+                if math.isfinite(val) and val >= 0:
+                    cost_float = val
+            except (ValueError, TypeError):
+                pass
+
+        cost_microdollars = usd_to_microdollars(cost_float)
+        session_id = session.session_id if session else None
+
+        if cost_float is not None:
+            tracking_type = "cost_usd"
+            tracking_amount = cost_float
+        else:
+            tracking_type = "tokens"
+            tracking_amount = total_tokens
+
+        _schedule_cost_log(
+            PlatformCostEntry(
+                user_id=user_id,
+                graph_exec_id=session_id,
+                block_id=COPILOT_BLOCK_ID,
+                block_name=_copilot_block_name(log_prefix),
+                provider=provider,
+                credential_id=COPILOT_CREDENTIAL_ID,
+                cost_microdollars=cost_microdollars,
+                input_tokens=prompt_tokens,
+                output_tokens=completion_tokens,
+                model=model,
+                tracking_type=tracking_type,
+                tracking_amount=tracking_amount,
+                metadata={
+                    "tracking_type": tracking_type,
+                    "tracking_amount": tracking_amount,
+                    "cache_read_tokens": cache_read_tokens,
+                    "cache_creation_tokens": cache_creation_tokens,
+                    "source": "copilot",
+                },
+            )
+        )

     return total_tokens

@@ -4,6 +4,7 @@ Covers both the baseline (prompt+completion only) and SDK (with cache breakdown)
 calling conventions, session persistence, and rate-limit recording.
 """

+import asyncio
 from datetime import UTC, datetime
 from unittest.mock import AsyncMock, patch

@@ -279,3 +280,290 @@ class TestRateLimitRecording:
                 completion_tokens=0,
             )
             mock_record.assert_not_awaited()
+
+
+# ---------------------------------------------------------------------------
+# PlatformCostLog integration
+# ---------------------------------------------------------------------------
+
+
+class TestPlatformCostLogging:
+    @pytest.mark.asyncio
+    async def test_logs_cost_entry_with_cost_usd(self):
+        """When cost_usd is provided, tracking_type should be 'cost_usd'."""
+        mock_log = AsyncMock()
+        with (
+            patch(
+                "backend.copilot.token_tracking.record_token_usage",
+                new_callable=AsyncMock,
+            ),
+            patch(
+                "backend.copilot.token_tracking.platform_cost_db",
+                return_value=type(
+                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
+                )(),
+            ),
+        ):
+            await persist_and_record_usage(
+                session=_make_session(),
+                user_id="user-cost",
+                prompt_tokens=200,
+                completion_tokens=100,
+                cost_usd=0.005,
+                model="gpt-4",
+                provider="anthropic",
+                log_prefix="[SDK]",
+            )
+            await asyncio.sleep(0)
+        mock_log.assert_awaited_once()
+        entry = mock_log.call_args[0][0]
+        assert entry.user_id == "user-cost"
+        assert entry.provider == "anthropic"
+        assert entry.model == "gpt-4"
+        assert entry.cost_microdollars == 5000
+        assert entry.input_tokens == 200
+        assert entry.output_tokens == 100
+        assert entry.tracking_type == "cost_usd"
+        assert entry.metadata["tracking_type"] == "cost_usd"
+        assert entry.metadata["tracking_amount"] == 0.005
+        assert entry.block_name == "copilot:SDK"
+        assert entry.graph_exec_id == "sess-test"

+    @pytest.mark.asyncio
+    async def test_logs_cost_entry_without_cost_usd(self):
+        """When cost_usd is None, tracking_type should be 'tokens'."""
+        mock_log = AsyncMock()
+        with (
+            patch(
+                "backend.copilot.token_tracking.record_token_usage",
+                new_callable=AsyncMock,
+            ),
+            patch(
+                "backend.copilot.token_tracking.platform_cost_db",
+                return_value=type(
+                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
+                )(),
+            ),
+        ):
+            await persist_and_record_usage(
+                session=None,
+                user_id="user-tokens",
+                prompt_tokens=100,
+                completion_tokens=50,
+                log_prefix="[Baseline]",
+            )
+            await asyncio.sleep(0)
+        mock_log.assert_awaited_once()
+        entry = mock_log.call_args[0][0]
+        assert entry.cost_microdollars is None
+        assert entry.tracking_type == "tokens"
+        assert entry.metadata["tracking_type"] == "tokens"
+        assert entry.metadata["tracking_amount"] == 150
+        assert entry.graph_exec_id is None
+        assert entry.block_name == "copilot:Baseline"
+
+    @pytest.mark.asyncio
+    async def test_skips_cost_log_when_no_user_id(self):
+        """No PlatformCostLog entry when user_id is None."""
+        mock_log = AsyncMock()
+        with (
+            patch(
+                "backend.copilot.token_tracking.record_token_usage",
+                new_callable=AsyncMock,
+            ),
+            patch(
+                "backend.copilot.token_tracking.platform_cost_db",
+                return_value=type(
+                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
+                )(),
+            ),
+        ):
+            await persist_and_record_usage(
+                session=None,
+                user_id=None,
+                prompt_tokens=100,
+                completion_tokens=50,
+            )
+            await asyncio.sleep(0)
+        mock_log.assert_not_awaited()
+
+    @pytest.mark.asyncio
+    async def test_cost_usd_invalid_string_falls_back_to_tokens(self):
+        """Invalid cost_usd string should fall back to tokens tracking."""
+        mock_log = AsyncMock()
+        with (
+            patch(
+                "backend.copilot.token_tracking.record_token_usage",
+                new_callable=AsyncMock,
+            ),
+            patch(
+                "backend.copilot.token_tracking.platform_cost_db",
+                return_value=type(
+                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
+                )(),
+            ),
+        ):
+            await persist_and_record_usage(
+                session=None,
+                user_id="user-invalid",
+                prompt_tokens=100,
+                completion_tokens=50,
+                cost_usd="not-a-number",
+            )
+            await asyncio.sleep(0)
+        mock_log.assert_awaited_once()
+        entry = mock_log.call_args[0][0]
+        assert entry.cost_microdollars is None
+        assert entry.metadata["tracking_type"] == "tokens"
+
+    @pytest.mark.asyncio
+    async def test_cost_usd_string_number_is_parsed(self):
+        """String-encoded cost_usd (e.g. from OpenRouter) should be parsed."""
+        mock_log = AsyncMock()
+        with (
+            patch(
+                "backend.copilot.token_tracking.record_token_usage",
+                new_callable=AsyncMock,
+            ),
+            patch(
+                "backend.copilot.token_tracking.platform_cost_db",
+                return_value=type(
+                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
+                )(),
+            ),
+        ):
+            await persist_and_record_usage(
+                session=None,
+                user_id="user-str",
+                prompt_tokens=100,
+                completion_tokens=50,
+                cost_usd="0.01",
+            )
+            await asyncio.sleep(0)
+        mock_log.assert_awaited_once()
+        entry = mock_log.call_args[0][0]
+        assert entry.cost_microdollars == 10_000
+        assert entry.metadata["tracking_type"] == "cost_usd"

+    @pytest.mark.asyncio
+    async def test_empty_log_prefix_produces_copilot_block_name(self):
+        """Empty log_prefix results in block_name='copilot'."""
+        mock_log = AsyncMock()
+        with (
+            patch(
+                "backend.copilot.token_tracking.record_token_usage",
+                new_callable=AsyncMock,
+            ),
+            patch(
+                "backend.copilot.token_tracking.platform_cost_db",
+                return_value=type(
+                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
+                )(),
+            ),
+        ):
+            await persist_and_record_usage(
+                session=None,
+                user_id="user-empty",
+                prompt_tokens=10,
+                completion_tokens=5,
+                log_prefix="",
+            )
+            await asyncio.sleep(0)
+        entry = mock_log.call_args[0][0]
+        assert entry.block_name == "copilot"
+
+    @pytest.mark.asyncio
+    async def test_cache_tokens_included_in_metadata(self):
+        """Cache token counts should be present in the metadata."""
+        mock_log = AsyncMock()
+        with (
+            patch(
+                "backend.copilot.token_tracking.record_token_usage",
+                new_callable=AsyncMock,
+            ),
+            patch(
+                "backend.copilot.token_tracking.platform_cost_db",
+                return_value=type(
+                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
+                )(),
+            ),
+        ):
+            await persist_and_record_usage(
+                session=None,
+                user_id="user-cache",
+                prompt_tokens=100,
+                completion_tokens=50,
+                cache_read_tokens=5000,
+                cache_creation_tokens=300,
+            )
+            await asyncio.sleep(0)
+        entry = mock_log.call_args[0][0]
+        assert entry.metadata["cache_read_tokens"] == 5000
+        assert entry.metadata["cache_creation_tokens"] == 300
+        assert entry.metadata["source"] == "copilot"
+
+    @pytest.mark.asyncio
+    async def test_logs_cost_only_when_tokens_zero(self):
+        """Zero prompt+completion tokens with cost_usd set still logs the entry."""
+        mock_log = AsyncMock()
+        with (
+            patch(
+                "backend.copilot.token_tracking.record_token_usage",
+                new_callable=AsyncMock,
+            ),
+            patch(
+                "backend.copilot.token_tracking.platform_cost_db",
+                return_value=type(
+                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
+                )(),
+            ),
+        ):
+            await persist_and_record_usage(
+                session=None,
+                user_id="user-cached",
+                prompt_tokens=0,
+                completion_tokens=0,
+                cost_usd=0.005,
+                model="claude-3-5-sonnet",
+                provider="anthropic",
+                log_prefix="[SDK]",
+            )
+            await asyncio.sleep(0)
+        # Guard: total_tokens == 0 but cost_usd is set — must still log
+        mock_log.assert_awaited_once()
+        entry = mock_log.call_args[0][0]
+        assert entry.user_id == "user-cached"
+        assert entry.tracking_type == "cost_usd"
+        assert entry.cost_microdollars == 5000
+        assert entry.input_tokens == 0
+        assert entry.output_tokens == 0
+
+    @pytest.mark.asyncio
+    async def test_negative_cost_usd_falls_back_to_tokens(self):
+        """Negative cost_usd must be rejected — val >= 0 guard in persist_and_record_usage."""
+        mock_log = AsyncMock()
+        with (
+            patch(
+                "backend.copilot.token_tracking.record_token_usage",
+                new_callable=AsyncMock,
+            ),
+            patch(
+                "backend.copilot.token_tracking.platform_cost_db",
+                return_value=type(
+                    "FakePlatformCostDb", (), {"log_platform_cost": mock_log}
+                )(),
+            ),
+        ):
+            await persist_and_record_usage(
+                session=None,
+                user_id="user-negative",
+                prompt_tokens=100,
+                completion_tokens=50,
+                cost_usd=-0.01,
+            )
+            await asyncio.sleep(0)
+        mock_log.assert_awaited_once()
+        entry = mock_log.call_args[0][0]
+        # Negative cost rejected — falls back to token-based tracking
+        assert entry.cost_microdollars is None
+        assert entry.metadata["tracking_type"] == "tokens"

@@ -845,6 +845,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
overwrite=overwrite,
|
||||
metadata={"origin": "agent-created"},
|
||||
)
|
||||
|
||||
# Build informative source label and message.
|
||||
|
||||
@@ -142,3 +142,16 @@ def credit_db():
        credit_db = get_database_manager_async_client()

    return credit_db


def platform_cost_db():
    if db.is_connected():
        from backend.data import platform_cost as _platform_cost_db

        platform_cost_db = _platform_cost_db
    else:
        from backend.util.clients import get_database_manager_async_client

        platform_cost_db = get_database_manager_async_client()

    return platform_cost_db

@@ -96,6 +96,7 @@ from backend.data.notifications import (
    remove_notifications_from_batch,
)
from backend.data.onboarding import increment_onboarding_runs
from backend.data.platform_cost import log_platform_cost
from backend.data.understanding import (
    get_business_understanding,
    upsert_business_understanding,
@@ -332,6 +333,9 @@ class DatabaseManager(AppService):
    get_blocks_needing_optimization = _(get_blocks_needing_optimization)
    update_block_optimized_description = _(update_block_optimized_description)

    # ============ Platform Cost Tracking ============ #
    log_platform_cost = _(log_platform_cost)

    # ============ CoPilot Chat Sessions ============ #
    get_chat_session = _(chat_db.get_chat_session)
    create_chat_session = _(chat_db.create_chat_session)
@@ -529,6 +533,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
    # ============ Block Descriptions ============ #
    get_blocks_needing_optimization = d.get_blocks_needing_optimization

    # ============ Platform Cost Tracking ============ #
    log_platform_cost = d.log_platform_cost

    # ============ CoPilot Chat Sessions ============ #
    get_chat_session = d.get_chat_session
    create_chat_session = d.create_chat_session

@@ -333,26 +333,29 @@ class BaseGraph(GraphBaseMeta):
        except Exception as e:
            logger.error(f"Invalid {type_class}: {input_default}, {e}")

        return {
            "type": "object",
            "properties": {
                p.name: {
                    **{
                        k: v
                        for k, v in p.generate_schema().items()
                        if k not in ["description", "default"]
                    },
                    "secret": p.secret,
                    # Default value has to be set for advanced fields.
                    "advanced": p.advanced and p.value is not None,
                    "title": p.title or p.name,
                    **({"description": p.description} if p.description else {}),
                    **({"default": p.value} if p.value is not None else {}),
                }
                for p in schema_fields
            },
            "required": [p.name for p in schema_fields if p.value is None],
        }
        try:
            return {
                "type": "object",
                "properties": {
                    p.name: {
                        **{
                            k: v
                            for k, v in p.generate_schema().items()
                            if k not in ["description", "default"]
                        },
                        "secret": p.secret,
                        # Default value has to be set for advanced fields.
                        "advanced": p.advanced and p.value is not None,
                        "title": p.title or p.name,
                        **({"description": p.description} if p.description else {}),
                        **({"default": p.value} if p.value is not None else {}),
                    }
                    for p in schema_fields
                },
                "required": [p.name for p in schema_fields if p.value is None],
            }
        except AttributeError as e:
            raise ValueError(str(e)) from e


class GraphTriggerInfo(BaseModel):

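The hunk above wraps the schema-building dict comprehension in a try/except so that an `AttributeError` from a malformed field object surfaces as `ValueError`. A minimal standalone sketch of that pattern follows; the names `generate_schema`, `Field`, and `Broken` are illustrative stand-ins, not identifiers from the codebase:

```python
def generate_schema(fields):
    """Build a schema dict from field objects.

    Attribute access on a malformed field (e.g. one created via
    model_construct(), which skips validation) raises AttributeError;
    we re-raise it as ValueError so a 400-style handler can catch it
    instead of a generic 500 catch-all.
    """
    try:
        return {"properties": {f.name: {"title": f.name} for f in fields}}
    except AttributeError as e:
        raise ValueError(str(e)) from e


class Field:
    def __init__(self, name):
        self.name = name


class Broken:
    # Deliberately has no .name attribute, mimicking a model_construct() artifact.
    pass


print(generate_schema([Field("x")]))   # {'properties': {'x': {'title': 'x'}}}
try:
    generate_schema([Broken()])
except ValueError as e:
    print("caught:", e)
```

The key design point is `raise ... from e`, which preserves the original traceback while changing the exception type the caller dispatches on.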
@@ -15,6 +15,7 @@ from backend.blocks.basic import StoreValueBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.data.graph import (
    Graph,
    GraphModel,
    Link,
    Node,
    get_graph,
@@ -1460,3 +1461,21 @@ async def test_validate_graph_execution_permissions_library_wrong_version_denied
    mock_is_published.assert_awaited_once_with(graph_id, graph_version)
    lib_where = mock_lib_prisma.return_value.find_first.call_args.kwargs["where"]
    assert lib_where["agentGraphVersion"] == graph_version


# ============================================================================
# Tests for _generate_schema AttributeError → ValueError conversion
# ============================================================================


def test_generate_schema_raises_value_error_when_name_missing():
    """AgentInputBlock.Input constructed without 'name' should raise ValueError.

    model_construct() skips validation, so the Input object is created without
    a 'name' attribute. The dict comprehension in _generate_schema then hits an
    AttributeError when it accesses p.name. That AttributeError must be caught
    and re-raised as ValueError so the existing 400 handler in rest_api.py fires
    instead of falling through to the 500 catch-all.
    """
    with pytest.raises(ValueError):
        GraphModel._generate_schema((AgentInputBlock.Input, {}))

@@ -21,7 +21,7 @@ from typing import (
)
from uuid import uuid4

from prisma.enums import CreditTransactionType, OnboardingStep
from prisma.enums import CreditTransactionType, OnboardingStep, SubscriptionTier
from pydantic import (
    BaseModel,
    ConfigDict,
@@ -54,7 +54,6 @@ class User(BaseModel):
    """Application-layer User model with snake_case convention."""

    model_config = ConfigDict(
        extra="forbid",
        str_strip_whitespace=True,
    )

@@ -104,6 +103,9 @@ class User(BaseModel):
        description="User timezone (IANA timezone identifier or 'not-set')",
    )

    # Subscription / rate-limit tier
    subscription_tier: SubscriptionTier | None = Field(default=None)

    @classmethod
    def from_db(cls, prisma_user: "PrismaUser") -> "User":
        """Convert a database User object to application User model."""
@@ -158,6 +160,7 @@ class User(BaseModel):
            notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
            notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
            timezone=prisma_user.timezone or USER_TIMEZONE_NOT_SET,
            subscription_tier=prisma_user.subscriptionTier,
        )


@@ -819,6 +822,17 @@ class RefundRequest(BaseModel):
    updated_at: datetime


ProviderCostType = Literal[
    "cost_usd",  # Actual USD cost reported by the provider
    "tokens",  # LLM token counts (sum of input + output)
    "characters",  # Per-character billing (TTS providers)
    "sandbox_seconds",  # Per-second compute billing (e.g. E2B)
    "walltime_seconds",  # Per-second billing incl. queue/polling
    "per_run",  # Per-API-call billing with fixed cost
    "items",  # Per-item billing (lead/organization/result count)
]


class NodeExecutionStats(BaseModel):
    """Execution statistics for a node execution."""

@@ -838,32 +852,39 @@
    output_token_count: int = 0
    extra_cost: int = 0
    extra_steps: int = 0
    provider_cost: float | None = None
    # Type of the provider-reported cost/usage captured above. When set
    # by a block, resolve_tracking honors this directly instead of
    # guessing from provider name.
    provider_cost_type: Optional[ProviderCostType] = None
    # Moderation fields
    cleared_inputs: Optional[dict[str, list[str]]] = None
    cleared_outputs: Optional[dict[str, list[str]]] = None

    def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
        """Mutate this instance by adding another NodeExecutionStats."""
        """Mutate this instance by adding another NodeExecutionStats.

        Avoids calling model_dump() twice per merge (called on every
        merge_stats() from ~20+ blocks); reads via getattr/vars instead.
        """
        if not isinstance(other, NodeExecutionStats):
            return NotImplemented

        stats_dict = other.model_dump()
        current_stats = self.model_dump()

        for key, value in stats_dict.items():
            if key not in current_stats:
                # Field doesn't exist yet, just set it
        for key in type(other).model_fields:
            value = getattr(other, key)
            if value is None:
                # Never overwrite an existing value with None
                continue
            current = getattr(self, key, None)
            if current is None:
                # Field doesn't exist yet or is None, just set it
                setattr(self, key, value)
            elif isinstance(value, dict) and isinstance(current_stats[key], dict):
                current_stats[key].update(value)
                setattr(self, key, current_stats[key])
            elif isinstance(value, (int, float)) and isinstance(
                current_stats[key], (int, float)
            ):
                setattr(self, key, current_stats[key] + value)
            elif isinstance(value, list) and isinstance(current_stats[key], list):
                current_stats[key].extend(value)
                setattr(self, key, current_stats[key])
            elif isinstance(value, dict) and isinstance(current, dict):
                current.update(value)
            elif isinstance(value, (int, float)) and isinstance(current, (int, float)):
                setattr(self, key, current + value)
            elif isinstance(value, list) and isinstance(current, list):
                current.extend(value)
            else:
                setattr(self, key, value)


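The rewritten `__iadd__` above merges stats field-by-field: `None` never overwrites an existing value, numbers accumulate, and dicts/lists merge in place. A minimal self-contained sketch of those merge semantics, using a hypothetical `MiniStats` class rather than the real pydantic model:

```python
from typing import Optional


class MiniStats:
    """Illustrative stand-in for NodeExecutionStats' merge behavior."""

    def __init__(self, tokens: int = 0, cost: Optional[float] = None,
                 cleared: Optional[dict] = None):
        self.tokens = tokens
        self.cost = cost
        self.cleared = cleared

    def __iadd__(self, other: "MiniStats") -> "MiniStats":
        for key in ("tokens", "cost", "cleared"):
            value = getattr(other, key)
            if value is None:
                continue  # never overwrite an existing value with None
            current = getattr(self, key, None)
            if current is None:
                setattr(self, key, value)  # first write wins
            elif isinstance(value, dict) and isinstance(current, dict):
                current.update(value)  # merge dicts in place
            elif isinstance(value, (int, float)) and isinstance(current, (int, float)):
                setattr(self, key, current + value)  # accumulate numbers
            else:
                setattr(self, key, value)
        return self


a = MiniStats(tokens=100, cost=0.5, cleared={"f1": ["v1"]})
a += MiniStats(tokens=50, cost=None, cleared={"f2": ["v2"]})
print(a.tokens, a.cost, a.cleared)
# 150 0.5 {'f1': ['v1'], 'f2': ['v2']}
```

Iterating over a fixed field list instead of `model_dump()` is what avoids the two full-model serializations per merge that the old implementation paid.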
@@ -1,7 +1,7 @@
import pytest
from pydantic import SecretStr

from backend.data.model import HostScopedCredentials
from backend.data.model import HostScopedCredentials, NodeExecutionStats


class TestHostScopedCredentials:
@@ -166,3 +166,84 @@ class TestHostScopedCredentials:
        )

        assert creds.matches_url(test_url) == expected


class TestNodeExecutionStatsIadd:
    def test_adds_numeric_fields(self):
        a = NodeExecutionStats(input_token_count=100, output_token_count=50)
        b = NodeExecutionStats(input_token_count=200, output_token_count=30)
        a += b
        assert a.input_token_count == 300
        assert a.output_token_count == 80

    def test_none_does_not_overwrite(self):
        a = NodeExecutionStats(provider_cost=0.5, error="some error")
        b = NodeExecutionStats(provider_cost=None, error=None)
        a += b
        assert a.provider_cost == 0.5
        assert a.error == "some error"

    def test_none_is_skipped_preserving_existing_value(self):
        a = NodeExecutionStats(input_token_count=100)
        b = NodeExecutionStats()
        a += b
        assert a.input_token_count == 100

    def test_dict_fields_are_merged(self):
        a = NodeExecutionStats(
            cleared_inputs={"field1": ["val1"]},
        )
        b = NodeExecutionStats(
            cleared_inputs={"field2": ["val2"]},
        )
        a += b
        assert a.cleared_inputs == {"field1": ["val1"], "field2": ["val2"]}

    def test_returns_self(self):
        a = NodeExecutionStats()
        b = NodeExecutionStats(input_token_count=10)
        result = a.__iadd__(b)
        assert result is a

    def test_not_implemented_for_non_stats(self):
        a = NodeExecutionStats()
        result = a.__iadd__("not a stats")  # type: ignore[arg-type]
        assert result is NotImplemented

    def test_error_none_does_not_clear_existing_error(self):
        a = NodeExecutionStats(error="existing error")
        b = NodeExecutionStats(error=None)
        a += b
        assert a.error == "existing error"

    def test_provider_cost_none_does_not_clear_existing_cost(self):
        a = NodeExecutionStats(provider_cost=0.05)
        b = NodeExecutionStats(provider_cost=None)
        a += b
        assert a.provider_cost == 0.05

    def test_provider_cost_accumulates_when_both_set(self):
        a = NodeExecutionStats(provider_cost=0.01)
        b = NodeExecutionStats(provider_cost=0.02)
        a += b
        assert abs((a.provider_cost or 0) - 0.03) < 1e-9

    def test_provider_cost_first_write_from_none(self):
        a = NodeExecutionStats()
        b = NodeExecutionStats(provider_cost=0.05)
        a += b
        assert a.provider_cost == 0.05

    def test_provider_cost_type_first_write_from_none(self):
        """Writing provider_cost_type into a stats with None sets it."""
        a = NodeExecutionStats()
        b = NodeExecutionStats(provider_cost_type="characters")
        a += b
        assert a.provider_cost_type == "characters"

    def test_provider_cost_type_none_does_not_overwrite(self):
        """A None provider_cost_type from other must not clear an existing value."""
        a = NodeExecutionStats(provider_cost_type="tokens")
        b = NodeExecutionStats()
        a += b
        assert a.provider_cost_type == "tokens"

378
autogpt_platform/backend/backend/data/platform_cost.py
Normal file
@@ -0,0 +1,378 @@
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Any

from prisma.models import PlatformCostLog as PrismaLog
from prisma.types import PlatformCostLogCreateInput
from pydantic import BaseModel

from backend.data.db import query_raw_with_schema
from backend.util.cache import cached
from backend.util.json import SafeJson

logger = logging.getLogger(__name__)

MICRODOLLARS_PER_USD = 1_000_000

# Dashboard query limits — keep in sync with the SQL queries below
MAX_PROVIDER_ROWS = 500
MAX_USER_ROWS = 100

# Default date range for dashboard queries when no start date is provided.
# Prevents full-table scans on large deployments.
DEFAULT_DASHBOARD_DAYS = 30


def usd_to_microdollars(cost_usd: float | None) -> int | None:
    """Convert a USD amount (float) to microdollars (int). None-safe."""
    if cost_usd is None:
        return None
    return round(cost_usd * MICRODOLLARS_PER_USD)


class PlatformCostEntry(BaseModel):
    user_id: str
    graph_exec_id: str | None = None
    node_exec_id: str | None = None
    graph_id: str | None = None
    node_id: str | None = None
    block_id: str | None = None
    block_name: str | None = None
    provider: str
    credential_id: str | None = None
    cost_microdollars: int | None = None
    input_tokens: int | None = None
    output_tokens: int | None = None
    data_size: int | None = None
    duration: float | None = None
    model: str | None = None
    tracking_type: str | None = None
    tracking_amount: float | None = None
    metadata: dict[str, Any] | None = None


async def log_platform_cost(entry: PlatformCostEntry) -> None:
    await PrismaLog.prisma().create(
        data=PlatformCostLogCreateInput(
            userId=entry.user_id,
            graphExecId=entry.graph_exec_id,
            nodeExecId=entry.node_exec_id,
            graphId=entry.graph_id,
            nodeId=entry.node_id,
            blockId=entry.block_id,
            blockName=entry.block_name,
            # Normalize to lowercase so the (provider, createdAt) index is always
            # used without LOWER() on the read side.
            provider=entry.provider.lower(),
            credentialId=entry.credential_id,
            costMicrodollars=entry.cost_microdollars,
            inputTokens=entry.input_tokens,
            outputTokens=entry.output_tokens,
            dataSize=entry.data_size,
            duration=entry.duration,
            model=entry.model,
            trackingType=entry.tracking_type,
            trackingAmount=entry.tracking_amount,
            metadata=SafeJson(entry.metadata or {}),
        )
    )


# Bound the number of concurrent cost-log DB inserts to prevent unbounded
# task/connection growth under sustained load or DB slowness.
_log_semaphore = asyncio.Semaphore(50)


async def log_platform_cost_safe(entry: PlatformCostEntry) -> None:
    """Fire-and-forget wrapper that never raises."""
    try:
        async with _log_semaphore:
            await log_platform_cost(entry)
    except Exception:
        logger.exception(
            "Failed to log platform cost for user=%s provider=%s block=%s",
            entry.user_id,
            entry.provider,
            entry.block_name,
        )


def _mask_email(email: str | None) -> str | None:
    """Mask an email address to reduce PII exposure in admin API responses.

    Turns 'user@example.com' into 'us***@example.com'.
    Handles short local parts gracefully (e.g. 'a@b.com' → 'a***@b.com').
    """
    if not email:
        return email
    at = email.find("@")
    if at < 0:
        return "***"
    local = email[:at]
    domain = email[at:]
    visible = local[:2] if len(local) >= 2 else local[:1]
    return f"{visible}***{domain}"


class ProviderCostSummary(BaseModel):
    provider: str
    tracking_type: str | None = None
    total_cost_microdollars: int
    total_input_tokens: int
    total_output_tokens: int
    total_duration_seconds: float = 0.0
    total_tracking_amount: float = 0.0
    request_count: int


class UserCostSummary(BaseModel):
    user_id: str | None = None
    email: str | None = None
    total_cost_microdollars: int
    total_input_tokens: int
    total_output_tokens: int
    request_count: int


class CostLogRow(BaseModel):
    id: str
    created_at: datetime
    user_id: str | None = None
    email: str | None = None
    graph_exec_id: str | None = None
    node_exec_id: str | None = None
    block_name: str
    provider: str
    tracking_type: str | None = None
    cost_microdollars: int | None = None
    input_tokens: int | None = None
    output_tokens: int | None = None
    duration: float | None = None
    model: str | None = None


class PlatformCostDashboard(BaseModel):
    by_provider: list[ProviderCostSummary]
    by_user: list[UserCostSummary]
    total_cost_microdollars: int
    total_requests: int
    total_users: int


def _build_where(
    start: datetime | None,
    end: datetime | None,
    provider: str | None,
    user_id: str | None,
    table_alias: str = "",
) -> tuple[str, list[Any]]:
    prefix = f"{table_alias}." if table_alias else ""
    clauses: list[str] = []
    params: list[Any] = []
    idx = 1

    if start:
        clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
        params.append(start)
        idx += 1
    if end:
        clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
        params.append(end)
        idx += 1
    if provider:
        # Provider names are normalized to lowercase at write time so a plain
        # equality check is sufficient and the (provider, createdAt) index is used.
        clauses.append(f'{prefix}"provider" = ${idx}')
        params.append(provider.lower())
        idx += 1
    if user_id:
        clauses.append(f'{prefix}"userId" = ${idx}')
        params.append(user_id)
        idx += 1

    return (" AND ".join(clauses) if clauses else "TRUE", params)


@cached(ttl_seconds=30)
async def get_platform_cost_dashboard(
    start: datetime | None = None,
    end: datetime | None = None,
    provider: str | None = None,
    user_id: str | None = None,
) -> PlatformCostDashboard:
    """Aggregate platform cost logs for the admin dashboard.

    Note: by_provider rows are keyed on (provider, tracking_type). A single
    provider can therefore appear in multiple rows if it has entries with
    different billing models (e.g. "openai" with both "tokens" and "cost_usd"
    if pricing is later added for some entries). Frontend treats each row
    independently rather than as a provider primary key.

    Defaults to the last DEFAULT_DASHBOARD_DAYS days when no start date is
    provided to avoid full-table scans on large deployments.
    """
    if start is None:
        start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
    where_p, params_p = _build_where(start, end, provider, user_id, "p")

    by_provider_rows, by_user_rows, total_user_rows = await asyncio.gather(
        query_raw_with_schema(
            f"""
            SELECT
                p."provider",
                p."trackingType" AS tracking_type,
                COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
                COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
                COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
                COALESCE(SUM(p."duration"), 0)::float AS total_duration,
                COALESCE(SUM(p."trackingAmount"), 0)::float AS total_tracking_amount,
                COUNT(*)::bigint AS request_count
            FROM {{schema_prefix}}"PlatformCostLog" p
            WHERE {where_p}
            GROUP BY p."provider", p."trackingType"
            ORDER BY total_cost DESC
            LIMIT {MAX_PROVIDER_ROWS}
            """,
            *params_p,
        ),
        query_raw_with_schema(
            f"""
            SELECT
                p."userId" AS user_id,
                u."email",
                COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
                COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
                COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
                COUNT(*)::bigint AS request_count
            FROM {{schema_prefix}}"PlatformCostLog" p
            LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
            WHERE {where_p}
            GROUP BY p."userId", u."email"
            ORDER BY total_cost DESC
            LIMIT {MAX_USER_ROWS}
            """,
            *params_p,
        ),
        query_raw_with_schema(
            f"""
            SELECT COUNT(DISTINCT p."userId")::bigint AS cnt
            FROM {{schema_prefix}}"PlatformCostLog" p
            WHERE {where_p}
            """,
            *params_p,
        ),
    )

    # Use the exact COUNT(DISTINCT userId) so total_users is not capped at
    # MAX_USER_ROWS (which would silently report 100 for >100 active users).
    total_users = int(total_user_rows[0]["cnt"]) if total_user_rows else 0
    total_cost = sum(r["total_cost"] for r in by_provider_rows)
    total_requests = sum(r["request_count"] for r in by_provider_rows)

    return PlatformCostDashboard(
        by_provider=[
            ProviderCostSummary(
                provider=r["provider"],
                tracking_type=r.get("tracking_type"),
                total_cost_microdollars=r["total_cost"],
                total_input_tokens=r["total_input_tokens"],
                total_output_tokens=r["total_output_tokens"],
                total_duration_seconds=r.get("total_duration", 0.0),
                total_tracking_amount=r.get("total_tracking_amount", 0.0),
                request_count=r["request_count"],
            )
            for r in by_provider_rows
        ],
        by_user=[
            UserCostSummary(
                user_id=r.get("user_id"),
                email=_mask_email(r.get("email")),
                total_cost_microdollars=r["total_cost"],
                total_input_tokens=r["total_input_tokens"],
                total_output_tokens=r["total_output_tokens"],
                request_count=r["request_count"],
            )
            for r in by_user_rows
        ],
        total_cost_microdollars=total_cost,
        total_requests=total_requests,
        total_users=total_users,
    )


async def get_platform_cost_logs(
    start: datetime | None = None,
    end: datetime | None = None,
    provider: str | None = None,
    user_id: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[CostLogRow], int]:
    if start is None:
        start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
    where_sql, params = _build_where(start, end, provider, user_id, "p")

    offset = (page - 1) * page_size
    limit_idx = len(params) + 1
    offset_idx = len(params) + 2

    count_rows, rows = await asyncio.gather(
        query_raw_with_schema(
            f"""
            SELECT COUNT(*)::bigint AS cnt
            FROM {{schema_prefix}}"PlatformCostLog" p
            WHERE {where_sql}
            """,
            *params,
        ),
        query_raw_with_schema(
            f"""
            SELECT
                p."id",
                p."createdAt" AS created_at,
                p."userId" AS user_id,
                u."email",
                p."graphExecId" AS graph_exec_id,
                p."nodeExecId" AS node_exec_id,
                p."blockName" AS block_name,
                p."provider",
                p."trackingType" AS tracking_type,
                p."costMicrodollars" AS cost_microdollars,
                p."inputTokens" AS input_tokens,
                p."outputTokens" AS output_tokens,
                p."duration",
                p."model"
            FROM {{schema_prefix}}"PlatformCostLog" p
            LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
            WHERE {where_sql}
            ORDER BY p."createdAt" DESC, p."id" DESC
            LIMIT ${limit_idx} OFFSET ${offset_idx}
            """,
            *params,
            page_size,
            offset,
        ),
    )
    total = count_rows[0]["cnt"] if count_rows else 0

    logs = [
        CostLogRow(
            id=r["id"],
            created_at=r["created_at"],
            user_id=r.get("user_id"),
            email=_mask_email(r.get("email")),
            graph_exec_id=r.get("graph_exec_id"),
            node_exec_id=r.get("node_exec_id"),
            block_name=r["block_name"],
            provider=r["provider"],
            tracking_type=r.get("tracking_type"),
            cost_microdollars=r.get("cost_microdollars"),
            input_tokens=r.get("input_tokens"),
            output_tokens=r.get("output_tokens"),
            duration=r.get("duration"),
            model=r.get("model"),
        )
        for r in rows
    ]
    return logs, total

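Two of the small helpers in the new platform_cost module above are easy to sanity-check in isolation. The sketch below re-implements their logic standalone (mirroring the diff, not importing the backend package) so the conversion and masking behavior can be seen directly:

```python
MICRODOLLARS_PER_USD = 1_000_000


def usd_to_microdollars(cost_usd):
    """Convert a USD float to integer microdollars; None passes through."""
    if cost_usd is None:
        return None
    return round(cost_usd * MICRODOLLARS_PER_USD)


def mask_email(email):
    """Mask the local part of an email, keeping at most two visible chars."""
    if not email:
        return email
    at = email.find("@")
    if at < 0:
        return "***"
    local, domain = email[:at], email[at:]
    visible = local[:2] if len(local) >= 2 else local[:1]
    return f"{visible}***{domain}"


print(usd_to_microdollars(0.005))      # 5000
print(mask_email("user@example.com"))  # us***@example.com
print(mask_email("a@b.com"))           # a***@b.com
```

Storing costs as integer microdollars avoids floating-point drift when summing many small per-request amounts in SQL aggregates.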
@@ -0,0 +1,79 @@
"""
Integration tests for platform cost logging.

These tests run actual database operations to verify that SafeJson metadata
round-trips correctly through Prisma — catching the DataError that occurred
when a plain Python dict was passed to the Prisma Json? field.
"""

import uuid

import pytest
from prisma.models import PlatformCostLog as PrismaLog
from prisma.models import User

from backend.util.json import SafeJson

from .platform_cost import PlatformCostEntry, log_platform_cost


@pytest.fixture
async def cost_log_user():
    """Create a throw-away user and clean up cost logs after the test."""
    user_id = str(uuid.uuid4())
    await User.prisma().create(
        data={
            "id": user_id,
            "email": f"cost-test-{user_id}@example.com",
            "topUpConfig": SafeJson({}),
            "timezone": "UTC",
        }
    )
    yield user_id
    await PrismaLog.prisma().delete_many(where={"userId": user_id})
    await User.prisma().delete(where={"id": user_id})


@pytest.mark.asyncio(loop_scope="session")
async def test_log_platform_cost_metadata_round_trip(cost_log_user):
    """
    Verify that SafeJson metadata is persisted and read back correctly.

    This test would have caught the DataError that silently swallowed all cost
    log writes when a plain Python dict was passed to the Prisma Json? field.
    """
    user_id = cost_log_user
    entry = PlatformCostEntry(
        user_id=user_id,
        block_name="TestBlock",
        provider="openai",
        cost_microdollars=5000,
        input_tokens=100,
        output_tokens=50,
        model="gpt-4",
        metadata={"key": "val", "nested": {"x": 1}},
    )
    await log_platform_cost(entry)

    rows = await PrismaLog.prisma().find_many(where={"userId": user_id})
    assert len(rows) == 1
    assert rows[0].metadata == {"key": "val", "nested": {"x": 1}}
    assert rows[0].provider == "openai"
    assert rows[0].costMicrodollars == 5000


@pytest.mark.asyncio(loop_scope="session")
async def test_log_platform_cost_metadata_none(cost_log_user):
    """Verify that None metadata falls back to {} (not a DataError)."""
    user_id = cost_log_user
    entry = PlatformCostEntry(
        user_id=user_id,
        block_name="TestBlock",
        provider="anthropic",
        metadata=None,
    )
    await log_platform_cost(entry)

    rows = await PrismaLog.prisma().find_many(where={"userId": user_id})
    assert len(rows) == 1
    assert rows[0].metadata == {}
286
autogpt_platform/backend/backend/data/platform_cost_test.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""Unit tests for helpers and async functions in platform_cost module."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma import Json
|
||||
|
||||
from backend.util.json import SafeJson
from .platform_cost import (
    PlatformCostEntry,
    _build_where,
    _mask_email,
    get_platform_cost_dashboard,
    get_platform_cost_logs,
    log_platform_cost,
    log_platform_cost_safe,
)


class TestMaskEmail:
    def test_typical_email(self):
        assert _mask_email("user@example.com") == "us***@example.com"

    def test_short_local_part(self):
        assert _mask_email("a@b.com") == "a***@b.com"

    def test_none_returns_none(self):
        assert _mask_email(None) is None

    def test_empty_string_returns_empty(self):
        assert _mask_email("") == ""

    def test_no_at_sign_returns_stars(self):
        assert _mask_email("notanemail") == "***"

    def test_two_char_local(self):
        assert _mask_email("ab@domain.org") == "ab***@domain.org"


class TestBuildWhere:
    def test_no_filters_returns_true(self):
        sql, params = _build_where(None, None, None, None)
        assert sql == "TRUE"
        assert params == []

    def test_start_only(self):
        dt = datetime(2026, 1, 1, tzinfo=timezone.utc)
        sql, params = _build_where(dt, None, None, None)
        assert '"createdAt" >= $1::timestamptz' in sql
        assert params == [dt]

    def test_end_only(self):
        dt = datetime(2026, 6, 1, tzinfo=timezone.utc)
        sql, params = _build_where(None, dt, None, None)
        assert '"createdAt" <= $1::timestamptz' in sql
        assert params == [dt]

    def test_provider_only(self):
        # Provider names are normalized to lowercase at write time, so the
        # filter uses a plain equality check. The input is also lowercased so
        # "OpenAI" and "openai" both match stored rows.
        sql, params = _build_where(None, None, "OpenAI", None)
        assert '"provider" = $1' in sql
        assert params == ["openai"]

    def test_user_id_only(self):
        sql, params = _build_where(None, None, None, "user-123")
        assert '"userId" = $1' in sql
        assert params == ["user-123"]

    def test_all_filters(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        end = datetime(2026, 6, 1, tzinfo=timezone.utc)
        sql, params = _build_where(start, end, "Anthropic", "u1")
        assert "$1" in sql
        assert "$2" in sql
        assert "$3" in sql
        assert "$4" in sql
        assert len(params) == 4
        # Provider is lowercased at filter time to match stored lowercase values.
        assert params == [start, end, "anthropic", "u1"]

    def test_table_alias(self):
        dt = datetime(2026, 1, 1, tzinfo=timezone.utc)
        sql, params = _build_where(dt, None, None, None, table_alias="p")
        assert 'p."createdAt"' in sql
        assert params == [dt]

    def test_clauses_joined_with_and(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        end = datetime(2026, 6, 1, tzinfo=timezone.utc)
        sql, _ = _build_where(start, end, None, None)
        assert " AND " in sql


def _make_entry(**overrides: object) -> PlatformCostEntry:
    return PlatformCostEntry.model_validate(
        {
            "user_id": "user-1",
            "block_id": "block-1",
            "block_name": "TestBlock",
            "provider": "openai",
            "credential_id": "cred-1",
            **overrides,
        }
    )


class TestLogPlatformCost:
    @pytest.mark.asyncio
    async def test_creates_prisma_record(self):
        mock_create = AsyncMock()
        with patch("backend.data.platform_cost.PrismaLog.prisma") as mock_prisma:
            mock_prisma.return_value.create = mock_create
            entry = _make_entry(
                input_tokens=100,
                output_tokens=50,
                cost_microdollars=5000,
                model="gpt-4",
                metadata={"key": "val"},
            )
            await log_platform_cost(entry)
        mock_create.assert_awaited_once()
        data = mock_create.call_args[1]["data"]
        assert data["userId"] == "user-1"
        assert data["blockName"] == "TestBlock"
        assert data["provider"] == "openai"
        # metadata must be wrapped in SafeJson (a prisma.Json subclass), not a plain dict
        assert isinstance(data["metadata"], Json)

    @pytest.mark.asyncio
    async def test_metadata_none_passes_none(self):
        mock_create = AsyncMock()
        with patch("backend.data.platform_cost.PrismaLog.prisma") as mock_prisma:
            mock_prisma.return_value.create = mock_create
            entry = _make_entry(metadata=None)
            await log_platform_cost(entry)
        data = mock_create.call_args[1]["data"]
        # None falls back to SafeJson({}) so Prisma always gets a valid Json value
        assert isinstance(data["metadata"], Json)
        assert data["metadata"] == SafeJson({})


class TestLogPlatformCostSafe:
    @pytest.mark.asyncio
    async def test_does_not_raise_on_error(self):
        with patch("backend.data.platform_cost.PrismaLog.prisma") as mock_prisma:
            mock_prisma.return_value.create = AsyncMock(
                side_effect=RuntimeError("DB down")
            )
            entry = _make_entry()
            await log_platform_cost_safe(entry)

    @pytest.mark.asyncio
    async def test_succeeds_when_no_error(self):
        mock_create = AsyncMock()
        with patch("backend.data.platform_cost.PrismaLog.prisma") as mock_prisma:
            mock_prisma.return_value.create = mock_create
            entry = _make_entry()
            await log_platform_cost_safe(entry)
        mock_create.assert_awaited_once()


class TestGetPlatformCostDashboard:
    def setup_method(self):
        # @cached stores results in-process; clear between tests to avoid bleed.
        get_platform_cost_dashboard.cache_clear()

    @pytest.mark.asyncio
    async def test_returns_dashboard_with_data(self):
        provider_rows = [
            {
                "provider": "openai",
                "tracking_type": "tokens",
                "total_cost": 5000,
                "total_input_tokens": 1000,
                "total_output_tokens": 500,
                "total_duration": 10.5,
                "request_count": 3,
            }
        ]
        user_rows = [
            {
                "user_id": "u1",
                "email": "a@b.com",
                "total_cost": 5000,
                "total_input_tokens": 1000,
                "total_output_tokens": 500,
                "request_count": 3,
            }
        ]
        # Dashboard runs 3 queries: by_provider, by_user, COUNT(DISTINCT userId).
        mock_query = AsyncMock(side_effect=[provider_rows, user_rows, [{"cnt": 1}]])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            dashboard = await get_platform_cost_dashboard()
        assert dashboard.total_cost_microdollars == 5000
        assert dashboard.total_requests == 3
        assert dashboard.total_users == 1
        assert len(dashboard.by_provider) == 1
        assert dashboard.by_provider[0].provider == "openai"
        assert dashboard.by_provider[0].tracking_type == "tokens"
        assert dashboard.by_provider[0].total_duration_seconds == 10.5
        assert len(dashboard.by_user) == 1
        assert dashboard.by_user[0].email == "a***@b.com"

    @pytest.mark.asyncio
    async def test_returns_empty_dashboard(self):
        mock_query = AsyncMock(side_effect=[[], [], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            dashboard = await get_platform_cost_dashboard()
        assert dashboard.total_cost_microdollars == 0
        assert dashboard.total_requests == 0
        assert dashboard.total_users == 0
        assert dashboard.by_provider == []
        assert dashboard.by_user == []

    @pytest.mark.asyncio
    async def test_passes_filters_to_queries(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        mock_query = AsyncMock(side_effect=[[], [], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            await get_platform_cost_dashboard(
                start=start, provider="openai", user_id="u1"
            )
        assert mock_query.await_count == 3
        first_call_sql = mock_query.call_args_list[0][0][0]
        assert "createdAt" in first_call_sql


class TestGetPlatformCostLogs:
    @pytest.mark.asyncio
    async def test_returns_logs_and_total(self):
        count_rows = [{"cnt": 1}]
        log_rows = [
            {
                "id": "log-1",
                "created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
                "user_id": "u1",
                "email": "a@b.com",
                "graph_exec_id": "g1",
                "node_exec_id": "n1",
                "block_name": "TestBlock",
                "provider": "openai",
                "tracking_type": "tokens",
                "cost_microdollars": 5000,
                "input_tokens": 100,
                "output_tokens": 50,
                "duration": 1.5,
                "model": "gpt-4",
            }
        ]
        mock_query = AsyncMock(side_effect=[count_rows, log_rows])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            logs, total = await get_platform_cost_logs(page=1, page_size=10)
        assert total == 1
        assert len(logs) == 1
        assert logs[0].id == "log-1"
        assert logs[0].provider == "openai"
        assert logs[0].model == "gpt-4"

    @pytest.mark.asyncio
    async def test_returns_empty_when_no_data(self):
        mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            logs, total = await get_platform_cost_logs()
        assert total == 0
        assert logs == []

    @pytest.mark.asyncio
    async def test_pagination_offset(self):
        mock_query = AsyncMock(side_effect=[[{"cnt": 100}], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            logs, total = await get_platform_cost_logs(page=3, page_size=25)
        assert total == 100
        second_call_args = mock_query.call_args_list[1][0]
        assert 25 in second_call_args  # page_size
        assert 50 in second_call_args  # offset = (3-1) * 25

    @pytest.mark.asyncio
    async def test_empty_count_returns_zero(self):
        mock_query = AsyncMock(side_effect=[[], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            logs, total = await get_platform_cost_logs()
        assert total == 0
 291  autogpt_platform/backend/backend/executor/cost_tracking.py  Normal file
@@ -0,0 +1,291 @@
"""Helpers for platform cost tracking on system-credential block executions."""

import asyncio
import logging
import threading
from typing import TYPE_CHECKING, Any, cast

from backend.blocks._base import Block, BlockSchema
from backend.copilot.token_tracking import _pending_log_tasks as _copilot_tasks
from backend.copilot.token_tracking import (
    _pending_log_tasks_lock as _copilot_tasks_lock,
)
from backend.data.execution import NodeExecutionEntry
from backend.data.model import NodeExecutionStats
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
from backend.executor.utils import block_usage_cost
from backend.integrations.credentials_store import is_system_credential
from backend.integrations.providers import ProviderName

if TYPE_CHECKING:
    from backend.data.db_manager import DatabaseManagerAsyncClient

logger = logging.getLogger(__name__)

# Provider groupings by billing model — used when the block didn't explicitly
# declare stats.provider_cost_type and we fall back to provider-name
# heuristics. Values match ProviderName enum values.
_CHARACTER_BILLED_PROVIDERS = frozenset(
    {ProviderName.D_ID.value, ProviderName.ELEVENLABS.value}
)
_WALLTIME_BILLED_PROVIDERS = frozenset(
    {
        ProviderName.FAL.value,
        ProviderName.REVID.value,
        ProviderName.REPLICATE.value,
    }
)

# Hold strong references to in-flight log tasks so the event loop doesn't
# garbage-collect them mid-execution. Tasks remove themselves on completion.
# _pending_log_tasks_lock guards all reads and writes: worker threads call
# discard() via done callbacks while drain_pending_cost_logs() iterates.
_pending_log_tasks: set[asyncio.Task] = set()
_pending_log_tasks_lock = threading.Lock()
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
# shared across event loops running in different threads. Key by loop instance
# so each executor worker thread gets its own semaphore.
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}


def _get_log_semaphore() -> asyncio.Semaphore:
    loop = asyncio.get_running_loop()
    sem = _log_semaphores.get(loop)
    if sem is None:
        sem = asyncio.Semaphore(50)
        _log_semaphores[loop] = sem
    return sem


async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
    """Await all in-flight cost log tasks with a timeout.

    Drains both the executor cost log tasks (_pending_log_tasks in this module,
    used for block execution cost tracking via DatabaseManagerAsyncClient) and
    the copilot cost log tasks (token_tracking._pending_log_tasks, used for
    copilot LLM turns via platform_cost_db()).

    Call this during graceful shutdown to flush pending INSERT tasks before
    the process exits. Tasks that don't complete within `timeout` seconds are
    abandoned and their failures are already logged by _safe_log.
    """
    # asyncio.wait() requires all tasks to belong to the running event loop.
    # _pending_log_tasks is shared across executor worker threads (each with
    # its own loop), so filter to only tasks owned by the current loop.
    # Acquire the lock to take a consistent snapshot (worker threads call
    # discard() via done callbacks concurrently with this iteration).
    current_loop = asyncio.get_running_loop()
    with _pending_log_tasks_lock:
        all_pending = [t for t in _pending_log_tasks if t.get_loop() is current_loop]
    if all_pending:
        logger.info("Draining %d executor cost log task(s)", len(all_pending))
        _, still_pending = await asyncio.wait(all_pending, timeout=timeout)
        if still_pending:
            logger.warning(
                "%d executor cost log task(s) did not complete within %.1fs",
                len(still_pending),
                timeout,
            )
    # Also drain copilot cost log tasks (token_tracking._pending_log_tasks)
    with _copilot_tasks_lock:
        copilot_pending = [t for t in _copilot_tasks if t.get_loop() is current_loop]
    if copilot_pending:
        logger.info("Draining %d copilot cost log task(s)", len(copilot_pending))
        _, still_pending = await asyncio.wait(copilot_pending, timeout=timeout)
        if still_pending:
            logger.warning(
                "%d copilot cost log task(s) did not complete within %.1fs",
                len(still_pending),
                timeout,
            )


def _schedule_log(
    db_client: "DatabaseManagerAsyncClient", entry: PlatformCostEntry
) -> None:
    async def _safe_log() -> None:
        async with _get_log_semaphore():
            try:
                await db_client.log_platform_cost(entry)
            except Exception:
                logger.exception(
                    "Failed to log platform cost for user=%s provider=%s block=%s",
                    entry.user_id,
                    entry.provider,
                    entry.block_name,
                )

    task = asyncio.create_task(_safe_log())
    with _pending_log_tasks_lock:
        _pending_log_tasks.add(task)

    def _remove(t: asyncio.Task) -> None:
        with _pending_log_tasks_lock:
            _pending_log_tasks.discard(t)

    task.add_done_callback(_remove)


def _extract_model_name(raw: str | dict | None) -> str | None:
    """Return a string model name from a block input field, or None.

    Handles str (returned as-is), dict (e.g. an enum wrapper, skipped), and
    None (no model field). Unexpected types are coerced to str as a fallback.
    """
    if raw is None:
        return None
    if isinstance(raw, str):
        return raw
    if isinstance(raw, dict):
        return None
    return str(raw)


def resolve_tracking(
    provider: str,
    stats: NodeExecutionStats,
    input_data: dict[str, Any],
) -> tuple[str, float]:
    """Return (tracking_type, tracking_amount) based on provider billing model.

    Preference order:
    1. Block-declared: if the block set `provider_cost_type` on its stats,
       honor it directly (paired with `provider_cost` as the amount).
    2. Heuristic fallback: infer from `provider_cost`/token counts, then
       from provider name for per-character / per-second billing.
    """
    # 1. Block explicitly declared its cost type (only when an amount is present)
    if stats.provider_cost_type and stats.provider_cost is not None:
        return stats.provider_cost_type, max(0.0, stats.provider_cost)

    # 2. Provider returned actual USD cost (OpenRouter, Exa)
    if stats.provider_cost is not None:
        return "cost_usd", max(0.0, stats.provider_cost)

    # 3. LLM providers: track by tokens
    if stats.input_token_count or stats.output_token_count:
        return "tokens", float(
            (stats.input_token_count or 0) + (stats.output_token_count or 0)
        )

    # 4. Provider-specific billing heuristics

    # TTS: billed per character of input text
    if provider == ProviderName.UNREAL_SPEECH.value:
        text = input_data.get("text", "")
        return "characters", float(len(text)) if isinstance(text, str) else 0.0

    # D-ID + ElevenLabs voice: billed per character of script
    if provider in _CHARACTER_BILLED_PROVIDERS:
        text = (
            input_data.get("script_input", "")
            or input_data.get("text", "")
            or input_data.get("script", "")  # VideoNarrationBlock uses `script`
        )
        return "characters", float(len(text)) if isinstance(text, str) else 0.0

    # E2B: billed per second of sandbox time
    if provider == ProviderName.E2B.value:
        return "sandbox_seconds", round(stats.walltime, 3) if stats.walltime else 0.0

    # Video/image gen: walltime includes queue + generation + polling
    if provider in _WALLTIME_BILLED_PROVIDERS:
        return "walltime_seconds", round(stats.walltime, 3) if stats.walltime else 0.0

    # Per-request: Google Maps, Ideogram, Nvidia, Apollo, etc.
    # All billed per API call - count 1 per block execution.
    return "per_run", 1.0


async def log_system_credential_cost(
    node_exec: NodeExecutionEntry,
    block: Block,
    stats: NodeExecutionStats,
    db_client: "DatabaseManagerAsyncClient",
) -> None:
    """Check if a system credential was used and log the platform cost.

    Routes through DatabaseManagerAsyncClient so the write goes via the
    message-passing DB service rather than calling Prisma directly (which
    is not connected in the executor process).

    Logs only the first matching system credential field (one log per
    execution). Any unexpected error is caught and logged — cost logging
    is strictly best-effort and must never disrupt block execution.

    Note: costMicrodollars is left null for providers that don't return
    a USD cost. The credit_cost in metadata captures our internal credit
    charge as a proxy.
    """
    try:
        if node_exec.execution_context.dry_run:
            return

        input_data = node_exec.inputs
        input_model = cast(type[BlockSchema], block.input_schema)

        for field_name in input_model.get_credentials_fields():
            cred_data = input_data.get(field_name)
            if not cred_data or not isinstance(cred_data, dict):
                continue
            cred_id = cred_data.get("id", "")
            if not cred_id or not is_system_credential(cred_id):
                continue

            model_name = _extract_model_name(input_data.get("model"))

            credit_cost, _ = block_usage_cost(block=block, input_data=input_data)

            provider_name = cred_data.get("provider", "unknown")
            tracking_type, tracking_amount = resolve_tracking(
                provider=provider_name,
                stats=stats,
                input_data=input_data,
            )

            # Only treat provider_cost as USD when the tracking type says so.
            # For other types (items, characters, per_run, ...) the
            # provider_cost field holds the raw amount, not a dollar value.
            # Use tracking_amount (the normalized value from resolve_tracking)
            # rather than raw stats.provider_cost to avoid unit mismatches.
            cost_microdollars = None
            if tracking_type == "cost_usd":
                cost_microdollars = usd_to_microdollars(tracking_amount)

            meta: dict[str, Any] = {
                "tracking_type": tracking_type,
                "tracking_amount": tracking_amount,
            }
            if credit_cost is not None:
                meta["credit_cost"] = credit_cost
            if stats.provider_cost is not None:
                # Use 'provider_cost_raw' — the value's unit varies by tracking
                # type (USD for cost_usd, count for items/characters/per_run, etc.)
                meta["provider_cost_raw"] = stats.provider_cost

            _schedule_log(
                db_client,
                PlatformCostEntry(
                    user_id=node_exec.user_id,
                    graph_exec_id=node_exec.graph_exec_id,
                    node_exec_id=node_exec.node_exec_id,
                    graph_id=node_exec.graph_id,
                    node_id=node_exec.node_id,
                    block_id=node_exec.block_id,
                    block_name=block.name,
                    provider=provider_name,
                    credential_id=cred_id,
                    cost_microdollars=cost_microdollars,
                    input_tokens=stats.input_token_count,
                    output_tokens=stats.output_token_count,
                    data_size=stats.output_size if stats.output_size > 0 else None,
                    duration=stats.walltime if stats.walltime > 0 else None,
                    model=model_name,
                    tracking_type=tracking_type,
                    tracking_amount=tracking_amount,
                    metadata=meta,
                ),
            )
            return  # One log per execution is enough
    except Exception:
        logger.exception("log_system_credential_cost failed unexpectedly")
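The `_pending_log_tasks` machinery above is an instance of a standard asyncio idiom: the event loop only keeps weak references to tasks, so a fire-and-forget task can be garbage-collected before it finishes unless something holds a strong reference. A standalone sketch of the pattern (independent of the AutoGPT codebase; names here are illustrative):

```python
import asyncio
import threading

# Strong references to in-flight tasks; tasks remove themselves when done.
_pending: set[asyncio.Task] = set()
_pending_lock = threading.Lock()


def schedule(coro) -> asyncio.Task:
    """Fire-and-forget a coroutine while keeping a strong reference.

    Without the set, asyncio's weak task references allow a pending task
    to be garbage-collected mid-flight and silently dropped.
    """
    task = asyncio.create_task(coro)
    with _pending_lock:
        _pending.add(task)

    def _remove(t: asyncio.Task) -> None:
        with _pending_lock:
            _pending.discard(t)

    task.add_done_callback(_remove)
    return task


async def drain(timeout: float = 5.0) -> int:
    """Await tasks owned by the current loop; return how many finished.

    Snapshot under the lock, and filter by loop: asyncio.wait() rejects
    tasks belonging to a different event loop.
    """
    loop = asyncio.get_running_loop()
    with _pending_lock:
        mine = [t for t in _pending if t.get_loop() is loop]
    if not mine:
        return 0
    done, _ = await asyncio.wait(mine, timeout=timeout)
    return len(done)


async def main() -> int:
    results: list[int] = []

    async def work(n: int) -> None:
        await asyncio.sleep(0)
        results.append(n)

    for i in range(3):
        schedule(work(i))
    return await drain()
```

The done-callback removal keeps the set from growing without bound, and draining at shutdown is exactly what `drain_pending_cost_logs` does for the real cost-log INSERTs.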
@@ -45,6 +45,10 @@ from backend.data.notifications import (
    ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.cost_tracking import (
    drain_pending_cost_logs,
    log_system_credential_cost,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.util import json
@@ -303,9 +307,18 @@ async def execute_node(

    # Handle regular credentials fields
    for field_name, input_type in input_model.get_credentials_fields().items():
        # Dry-run platform credentials bypass the credential store.
        # Keep the existing credential metadata so _execute's input_schema(**...)
        # doesn't fail on the required field. If no metadata is present,
        # synthesize a minimal placeholder from the platform credentials.
        if _dry_run_creds is not None:
            if input_data.get(field_name) is None:
                input_data[field_name] = {
                    "id": _dry_run_creds.id,
                    "provider": _dry_run_creds.provider,
                    "type": _dry_run_creds.type,
                    "title": _dry_run_creds.title,
                }
            extra_exec_kwargs[field_name] = _dry_run_creds
            continue
@@ -692,6 +705,15 @@ class ExecutionProcessor:
            stats=graph_stats,
        )

        # Log platform cost if system credentials were used (only on success)
        if status == ExecutionStatus.COMPLETED:
            await log_system_credential_cost(
                node_exec=node_exec,
                block=node.block,
                stats=execution_stats,
                db_client=db_client,
            )

        return execution_stats

    @async_time_measured
@@ -2044,6 +2066,18 @@ class ExecutionManager(AppProcess):
            prefix + " [cancel-consumer]",
        )

        # Drain any in-flight cost log tasks before exit so we don't silently
        # drop INSERT operations during deployments.
        loop = getattr(self, "node_execution_loop", None)
        if loop is not None and loop.is_running():
            try:
                asyncio.run_coroutine_threadsafe(
                    drain_pending_cost_logs(), loop
                ).result(timeout=10)
                logger.info(f"{prefix} ✅ Cost log tasks drained")
            except Exception as e:
                logger.warning(f"{prefix} ⚠️ Failed to drain cost log tasks: {e}")

        logger.info(f"{prefix} ✅ Finished GraphExec cleanup")

        super().cleanup()
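The cost amounts drained above are stored as integer micro-dollars via `usd_to_microdollars`. Assuming that helper is the obvious fixed-point conversion (1 USD = 1,000,000 µ$ — an assumption; the real helper in `platform_cost.py` may round differently), it can be sketched as:

```python
def usd_to_microdollars(usd: float) -> int:
    """Convert a USD amount to integer micro-dollars.

    Assumed conversion: 1 USD = 1_000_000 µ$. Storing costs as integers
    avoids float drift when summing many tiny per-request charges
    (e.g. $0.0042 becomes 4200 µ$).
    """
    return round(usd * 1_000_000)
```

This also explains why the dashboard tests assert on `total_cost_microdollars == 5000` rather than on a float dollar figure.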
@@ -0,0 +1,623 @@
|
||||
"""Unit tests for resolve_tracking and log_system_credential_cost."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionContext, NodeExecutionEntry
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.executor.cost_tracking import (
|
||||
drain_pending_cost_logs,
|
||||
log_system_credential_cost,
|
||||
resolve_tracking,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveTracking:
|
||||
def _stats(self, **overrides: Any) -> NodeExecutionStats:
|
||||
return NodeExecutionStats(**overrides)
|
||||
|
||||
def test_provider_cost_returns_cost_usd(self):
|
||||
stats = self._stats(provider_cost=0.0042)
|
||||
tt, amt = resolve_tracking("openai", stats, {})
|
||||
assert tt == "cost_usd"
|
||||
assert amt == 0.0042
|
||||
|
||||
def test_token_counts_return_tokens(self):
|
||||
stats = self._stats(input_token_count=300, output_token_count=100)
|
||||
tt, amt = resolve_tracking("anthropic", stats, {})
|
||||
assert tt == "tokens"
|
||||
assert amt == 400.0
|
||||
|
||||
def test_token_counts_only_input(self):
|
||||
stats = self._stats(input_token_count=500)
|
||||
tt, amt = resolve_tracking("groq", stats, {})
|
||||
assert tt == "tokens"
|
||||
assert amt == 500.0
|
||||
|
||||
def test_unreal_speech_returns_characters(self):
|
||||
stats = self._stats()
|
||||
tt, amt = resolve_tracking("unreal_speech", stats, {"text": "Hello world"})
|
||||
assert tt == "characters"
|
||||
assert amt == 11.0
|
||||
|
||||
def test_unreal_speech_empty_text(self):
|
||||
stats = self._stats()
|
||||
tt, amt = resolve_tracking("unreal_speech", stats, {"text": ""})
|
||||
assert tt == "characters"
|
||||
assert amt == 0.0
|
||||
|
||||
def test_unreal_speech_non_string_text(self):
|
||||
stats = self._stats()
|
||||
tt, amt = resolve_tracking("unreal_speech", stats, {"text": 123})
|
||||
assert tt == "characters"
|
||||
assert amt == 0.0
|
||||
|
||||
def test_d_id_uses_script_input(self):
|
||||
stats = self._stats()
|
||||
tt, amt = resolve_tracking("d_id", stats, {"script_input": "Hello"})
|
||||
assert tt == "characters"
|
||||
assert amt == 5.0
|
||||
|
||||
def test_elevenlabs_uses_text(self):
|
||||
stats = self._stats()
|
||||
tt, amt = resolve_tracking("elevenlabs", stats, {"text": "Say this"})
|
||||
assert tt == "characters"
|
||||
assert amt == 8.0
|
||||
|
||||
def test_elevenlabs_fallback_to_text_when_no_script_input(self):
|
||||
stats = self._stats()
|
||||
tt, amt = resolve_tracking("elevenlabs", stats, {"text": "Fallback text"})
|
||||
assert tt == "characters"
|
||||
assert amt == 13.0
|
||||
|
||||
def test_elevenlabs_uses_script_field(self):
|
||||
"""VideoNarrationBlock (elevenlabs) uses `script` field, not script_input/text."""
|
||||
stats = self._stats()
|
||||
tt, amt = resolve_tracking("elevenlabs", stats, {"script": "Narration"})
|
||||
assert tt == "characters"
|
||||
assert amt == 9.0
|
||||
|
||||
def test_block_declared_cost_type_items(self):
|
||||
"""Block explicitly setting provider_cost_type='items' short-circuits heuristics."""
|
||||
stats = self._stats(provider_cost=5.0, provider_cost_type="items")
|
||||
tt, amt = resolve_tracking("google_maps", stats, {})
|
||||
assert tt == "items"
|
||||
assert amt == 5.0
|
||||
|
||||
def test_block_declared_cost_type_characters(self):
|
||||
"""TTS block can declare characters directly, bypassing input_data lookup."""
|
||||
stats = self._stats(provider_cost=42.0, provider_cost_type="characters")
|
||||
tt, amt = resolve_tracking("unreal_speech", stats, {})
|
||||
assert tt == "characters"
|
||||
assert amt == 42.0
|
||||
|
||||
def test_block_declared_cost_type_wins_over_tokens(self):
|
||||
"""provider_cost_type takes precedence over token-based heuristic."""
|
||||
stats = self._stats(
|
||||
provider_cost=1.0,
|
||||
provider_cost_type="per_run",
|
||||
input_token_count=500,
|
||||
)
|
||||
tt, amt = resolve_tracking("openai", stats, {})
|
||||
assert tt == "per_run"
|
||||
assert amt == 1.0
|
||||
|
||||
def test_e2b_returns_sandbox_seconds(self):
|
||||
stats = self._stats(walltime=45.123)
|
||||
tt, amt = resolve_tracking("e2b", stats, {})
|
||||
assert tt == "sandbox_seconds"
|
||||
assert amt == 45.123
|
||||
|
||||
def test_e2b_no_walltime(self):
|
||||
stats = self._stats(walltime=0)
|
||||
tt, amt = resolve_tracking("e2b", stats, {})
|
||||
assert tt == "sandbox_seconds"
|
||||
assert amt == 0.0
|
||||
|
||||
def test_fal_returns_walltime(self):
|
||||
stats = self._stats(walltime=12.5)
|
||||
tt, amt = resolve_tracking("fal", stats, {})
|
||||
assert tt == "walltime_seconds"
|
||||
assert amt == 12.5
|
||||
|
||||
def test_revid_returns_walltime(self):
|
||||
stats = self._stats(walltime=60.0)
|
||||
tt, amt = resolve_tracking("revid", stats, {})
|
||||
assert tt == "walltime_seconds"
|
||||
assert amt == 60.0
|
||||
|
||||
def test_replicate_returns_walltime(self):
|
||||
stats = self._stats(walltime=30.0)
|
||||
tt, amt = resolve_tracking("replicate", stats, {})
|
||||
assert tt == "walltime_seconds"
|
||||
assert amt == 30.0
|
||||
|
||||
def test_unknown_provider_returns_per_run(self):
|
||||
stats = self._stats()
|
||||
tt, amt = resolve_tracking("google_maps", stats, {})
|
||||
assert tt == "per_run"
|
||||
assert amt == 1.0
|
||||
|
||||
def test_negative_provider_cost_clamped_to_zero(self):
|
||||
"""Negative provider_cost values must be clamped to 0."""
|
||||
stats = self._stats(provider_cost=-0.005)
|
||||
tt, amt = resolve_tracking("openrouter", stats, {})
|
||||
assert tt == "cost_usd"
|
||||
assert amt == 0.0
|
||||
|
||||
def test_negative_block_declared_cost_clamped_to_zero(self):
|
||||
"""Negative block-declared cost must also be clamped to 0."""
|
||||
stats = self._stats(provider_cost=-1.0, provider_cost_type="items")
|
||||
tt, amt = resolve_tracking("google_maps", stats, {})
|
||||
assert tt == "items"
|
||||
assert amt == 0.0
|
||||
|
||||
def test_provider_cost_takes_precedence_over_tokens(self):
|
||||
stats = self._stats(
|
||||
provider_cost=0.01, input_token_count=500, output_token_count=200
|
||||
)
|
||||
tt, amt = resolve_tracking("openai", stats, {})
|
||||
assert tt == "cost_usd"
|
||||
assert amt == 0.01
|
||||
|
||||
def test_provider_cost_zero_is_not_none(self):
|
||||
"""provider_cost=0.0 is falsy but should still be tracked as cost_usd
|
||||
(e.g. free-tier or fully-cached responses from OpenRouter)."""
|
||||
stats = self._stats(provider_cost=0.0)
|
||||
tt, amt = resolve_tracking("open_router", stats, {})
|
||||
assert tt == "cost_usd"
|
||||
assert amt == 0.0
|
||||
|
||||
def test_tokens_take_precedence_over_provider_specific(self):
|
||||
stats = self._stats(input_token_count=100, walltime=10.0)
|
||||
tt, amt = resolve_tracking("fal", stats, {})
|
||||
assert tt == "tokens"
|
||||
assert amt == 100.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# log_system_credential_cost
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_db_client() -> MagicMock:
|
||||
db_client = MagicMock()
|
||||
db_client.log_platform_cost = AsyncMock()
|
||||
return db_client
|
||||
|
||||
|
||||
def _make_block(has_credentials: bool = True) -> MagicMock:
|
||||
block = MagicMock()
|
||||
block.name = "TestBlock"
|
||||
input_schema = MagicMock()
|
||||
if has_credentials:
|
||||
input_schema.get_credentials_fields.return_value = {"credentials": MagicMock()}
|
||||
else:
|
||||
input_schema.get_credentials_fields.return_value = {}
|
||||
block.input_schema = input_schema
|
||||
return block
|
||||
|
||||
|
||||
def _make_node_exec(
|
||||
inputs: dict | None = None,
|
||||
dry_run: bool = False,
|
||||
) -> NodeExecutionEntry:
|
||||
return NodeExecutionEntry(
|
||||
user_id="user-1",
|
||||
graph_exec_id="gx-1",
|
||||
graph_id="g-1",
|
||||
graph_version=1,
|
||||
node_exec_id="nx-1",
|
||||
node_id="n-1",
|
||||
block_id="b-1",
|
||||
inputs=inputs or {},
|
||||
execution_context=ExecutionContext(dry_run=dry_run),
|
||||
)
|
||||
|
||||
|
||||
class TestLogSystemCredentialCost:
    @pytest.mark.asyncio
    async def test_skips_dry_run(self):
        db_client = _make_db_client()
        node_exec = _make_node_exec(dry_run=True)
        block = _make_block()
        stats = NodeExecutionStats()
        await log_system_credential_cost(node_exec, block, stats, db_client)
        db_client.log_platform_cost.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_skips_when_no_credential_fields(self):
        db_client = _make_db_client()
        node_exec = _make_node_exec(inputs={})
        block = _make_block(has_credentials=False)
        stats = NodeExecutionStats()
        await log_system_credential_cost(node_exec, block, stats, db_client)
        db_client.log_platform_cost.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_skips_when_cred_data_missing(self):
        db_client = _make_db_client()
        node_exec = _make_node_exec(inputs={})
        block = _make_block()
        stats = NodeExecutionStats()
        await log_system_credential_cost(node_exec, block, stats, db_client)
        db_client.log_platform_cost.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_skips_when_not_system_credential(self):
        db_client = _make_db_client()
        with patch(
            "backend.executor.cost_tracking.is_system_credential",
            return_value=False,
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "user-cred-123", "provider": "openai"},
                }
            )
            block = _make_block()
            stats = NodeExecutionStats()
            await log_system_credential_cost(node_exec, block, stats, db_client)
            db_client.log_platform_cost.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_logs_with_system_credential(self):
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(10, None),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred-1", "provider": "openai"},
                    "model": "gpt-4",
                }
            )
            block = _make_block()
            stats = NodeExecutionStats(input_token_count=500, output_token_count=200)
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

            db_client.log_platform_cost.assert_awaited_once()
            entry = db_client.log_platform_cost.call_args[0][0]
            assert entry.user_id == "user-1"
            assert entry.provider == "openai"
            assert entry.block_name == "TestBlock"
            assert entry.model == "gpt-4"
            assert entry.input_tokens == 500
            assert entry.output_tokens == 200
            assert entry.tracking_type == "tokens"
            assert entry.metadata["tracking_type"] == "tokens"
            assert entry.metadata["tracking_amount"] == 700.0
            assert entry.metadata["credit_cost"] == 10

    @pytest.mark.asyncio
    async def test_logs_with_provider_cost(self):
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(5, None),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred-2", "provider": "open_router"},
                }
            )
            block = _make_block()
            stats = NodeExecutionStats(provider_cost=0.0015)
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

            entry = db_client.log_platform_cost.call_args[0][0]
            assert entry.cost_microdollars == 1500
            assert entry.tracking_type == "cost_usd"
            assert entry.metadata["tracking_type"] == "cost_usd"
            assert entry.metadata["provider_cost_raw"] == 0.0015

    @pytest.mark.asyncio
    async def test_model_name_enum_converted_to_str(self):
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(0, None),
            ),
        ):
            from enum import Enum

            class FakeModel(Enum):
                GPT4 = "gpt-4"

            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred", "provider": "openai"},
                    "model": FakeModel.GPT4,
                }
            )
            block = _make_block()
            stats = NodeExecutionStats()
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

            entry = db_client.log_platform_cost.call_args[0][0]
            assert entry.model == "FakeModel.GPT4"

    @pytest.mark.asyncio
    async def test_model_name_dict_becomes_none(self):
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(0, None),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred", "provider": "openai"},
                    "model": {"nested": "value"},
                }
            )
            block = _make_block()
            stats = NodeExecutionStats()
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

            entry = db_client.log_platform_cost.call_args[0][0]
            assert entry.model is None

    @pytest.mark.asyncio
    async def test_does_not_raise_when_block_usage_cost_raises(self):
        """log_system_credential_cost must swallow exceptions from block_usage_cost."""
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                side_effect=RuntimeError("pricing lookup failed"),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred", "provider": "openai"},
                }
            )
            block = _make_block()
            stats = NodeExecutionStats()
            # Should not raise — outer except must catch block_usage_cost error
            await log_system_credential_cost(node_exec, block, stats, db_client)

    @pytest.mark.asyncio
    async def test_round_instead_of_int_for_microdollars(self):
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(0, None),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred", "provider": "openai"},
                }
            )
            block = _make_block()
            # 0.0015 * 1_000_000 = 1499.9999999... with float math
            # round() should give 1500, int() would give 1499
            stats = NodeExecutionStats(provider_cost=0.0015)
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

            entry = db_client.log_platform_cost.call_args[0][0]
            assert entry.cost_microdollars == 1500

    @pytest.mark.asyncio
    async def test_per_run_metadata_has_no_provider_cost_raw(self):
        """For per-run providers (google_maps etc), provider_cost_raw is absent
        from metadata since stats.provider_cost is None."""
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(0, None),
            ),
        ):
            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred", "provider": "google_maps"},
                }
            )
            block = _make_block()
            stats = NodeExecutionStats()  # no provider_cost
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

            entry = db_client.log_platform_cost.call_args[0][0]
            assert entry.tracking_type == "per_run"
            assert "provider_cost_raw" not in (entry.metadata or {})


# ---------------------------------------------------------------------------
# merge_stats accumulation
# ---------------------------------------------------------------------------


class TestMergeStats:
    """Tests for NodeExecutionStats accumulation via += (used by Block.merge_stats)."""

    def test_accumulates_output_size(self):
        stats = NodeExecutionStats()
        stats += NodeExecutionStats(output_size=10)
        stats += NodeExecutionStats(output_size=25)
        assert stats.output_size == 35

    def test_accumulates_tokens(self):
        stats = NodeExecutionStats()
        stats += NodeExecutionStats(input_token_count=100, output_token_count=50)
        stats += NodeExecutionStats(input_token_count=200, output_token_count=150)
        assert stats.input_token_count == 300
        assert stats.output_token_count == 200

    def test_preserves_provider_cost(self):
        stats = NodeExecutionStats()
        stats += NodeExecutionStats(provider_cost=0.005)
        stats += NodeExecutionStats(output_size=10)
        assert stats.provider_cost == 0.005
        assert stats.output_size == 10

    def test_provider_cost_accumulates(self):
        """Multiple merge_stats with provider_cost should sum (multi-round
        tool-calling in copilot / retries can report cost separately)."""
        stats = NodeExecutionStats()
        stats += NodeExecutionStats(provider_cost=0.001)
        stats += NodeExecutionStats(provider_cost=0.002)
        stats += NodeExecutionStats(provider_cost=0.003)
        assert stats.provider_cost == pytest.approx(0.006)

    def test_provider_cost_none_does_not_overwrite(self):
        """A None provider_cost must not wipe a previously-set value."""
        stats = NodeExecutionStats(provider_cost=0.01)
        stats += NodeExecutionStats()  # provider_cost=None by default
        assert stats.provider_cost == 0.01

    def test_provider_cost_type_last_write_wins(self):
        """provider_cost_type is a Literal — last set value wins on merge."""
        stats = NodeExecutionStats(provider_cost_type="tokens")
        stats += NodeExecutionStats(provider_cost_type="items")
        assert stats.provider_cost_type == "items"


# ---------------------------------------------------------------------------
# on_node_execution -> log_system_credential_cost integration
# ---------------------------------------------------------------------------


class TestManagerCostTrackingIntegration:
    @pytest.mark.asyncio
    async def test_log_called_with_accumulated_stats(self):
        """Verify that log_system_credential_cost receives stats that could
        have been accumulated by merge_stats across multiple yield steps."""
        db_client = _make_db_client()
        with (
            patch(
                "backend.executor.cost_tracking.is_system_credential", return_value=True
            ),
            patch(
                "backend.executor.cost_tracking.block_usage_cost",
                return_value=(5, None),
            ),
        ):
            stats = NodeExecutionStats()
            stats += NodeExecutionStats(output_size=10, input_token_count=100)
            stats += NodeExecutionStats(output_size=25, input_token_count=200)

            assert stats.output_size == 35
            assert stats.input_token_count == 300

            node_exec = _make_node_exec(
                inputs={
                    "credentials": {"id": "sys-cred-acc", "provider": "openai"},
                    "model": "gpt-4",
                }
            )
            block = _make_block()
            await log_system_credential_cost(node_exec, block, stats, db_client)
            await asyncio.sleep(0)

            db_client.log_platform_cost.assert_awaited_once()
            entry = db_client.log_platform_cost.call_args[0][0]
            assert entry.input_tokens == 300
            assert entry.tracking_type == "tokens"
            assert entry.metadata["tracking_amount"] == 300.0

    @pytest.mark.asyncio
    async def test_skips_cost_log_when_status_is_failed(self):
        """Manager only calls log_system_credential_cost on COMPLETED status.

        This test verifies the guard condition `if status == COMPLETED` directly:
        calling log_system_credential_cost only happens on success, never on
        FAILED or ERROR executions.
        """
        from backend.data.execution import ExecutionStatus

        db_client = _make_db_client()
        node_exec = _make_node_exec(
            inputs={"credentials": {"id": "sys-cred", "provider": "openai"}}
        )
        block = _make_block()
        stats = NodeExecutionStats(input_token_count=100)

        # Simulate the manager guard: only call on COMPLETED
        status = ExecutionStatus.FAILED
        if status == ExecutionStatus.COMPLETED:
            await log_system_credential_cost(node_exec, block, stats, db_client)

        db_client.log_platform_cost.assert_not_awaited()


# ---------------------------------------------------------------------------
# drain_pending_cost_logs
# ---------------------------------------------------------------------------


class TestDrainPendingCostLogs:
    @pytest.mark.asyncio
    async def test_drain_empty_set_completes(self):
        """drain_pending_cost_logs should succeed silently with no pending tasks."""
        # Ensure both pending task sets are empty before calling drain
        import backend.copilot.token_tracking as tt
        import backend.executor.cost_tracking as ct

        ct._pending_log_tasks.clear()
        tt._pending_log_tasks.clear()
        # Should not raise
        await drain_pending_cost_logs(timeout=1.0)

    @pytest.mark.asyncio
    async def test_drain_awaits_in_flight_tasks(self):
        """drain_pending_cost_logs waits for tasks on the current loop."""
        import backend.executor.cost_tracking as ct

        finished = []

        async def _slow():
            await asyncio.sleep(0)
            finished.append(1)

        task = asyncio.ensure_future(_slow())
        with ct._pending_log_tasks_lock:
            ct._pending_log_tasks.add(task)
            task.add_done_callback(lambda t: ct._pending_log_tasks.discard(t))

        await drain_pending_cost_logs(timeout=2.0)
        assert finished == [1], "drain_pending_cost_logs should have awaited the task"
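The pattern under test here, fire-and-forget logging tasks tracked in a module-level set and flushed at shutdown, can be sketched in isolation; names like `fire_and_forget` and `drain` are illustrative, not the real cost_tracking API:

```python
import asyncio

_pending: set[asyncio.Task] = set()


def fire_and_forget(coro) -> asyncio.Task:
    # keep a strong reference so the task isn't garbage-collected mid-flight
    task = asyncio.ensure_future(coro)
    _pending.add(task)
    task.add_done_callback(_pending.discard)
    return task


async def drain(timeout: float = 5.0) -> None:
    # snapshot first: done-callbacks mutate the set while we wait
    tasks = list(_pending)
    if tasks:
        await asyncio.wait(tasks, timeout=timeout)


async def main() -> list[int]:
    done: list[int] = []

    async def log_cost() -> None:
        await asyncio.sleep(0)
        done.append(1)

    fire_and_forget(log_cost())
    await drain()
    return done


result = asyncio.run(main())
print(result)  # [1]
```

The done-callback removes completed tasks from the set, so `drain` only ever waits on work that is genuinely still in flight.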

@@ -18,6 +18,7 @@ from backend.executor.simulator import (
    _truncate_input_values,
    _truncate_value,
    build_simulation_prompt,
    get_dry_run_credentials,
    prepare_dry_run,
    simulate_block,
)
@@ -234,6 +235,42 @@ class TestPrepareDryRun:
        assert result is None


class TestGetDryRunCredentials:
    """get_dry_run_credentials pops _dry_run_api_key and returns APIKeyCredentials.

    The returned object must have fields that can be serialised into a valid
    CredentialsMetaInput placeholder dict for manager.py's schema-construction fix
    (Bug: manager.py nullified input_data[field_name] = None, which caused
    _execute's input_schema(**...) to fail because required credential fields were
    missing after the None-filter pass).
    """

    def test_returns_credentials_when_key_present(self) -> None:
        input_data = {"_dry_run_api_key": "sk-or-test", "other": "val"}
        creds = get_dry_run_credentials(input_data)
        assert creds is not None
        assert creds.api_key.get_secret_value() == "sk-or-test"
        # key is consumed from input_data
        assert "_dry_run_api_key" not in input_data

    def test_returns_none_when_key_absent(self) -> None:
        input_data: dict = {"other": "val"}
        creds = get_dry_run_credentials(input_data)
        assert creds is None

    def test_credentials_have_metadata_fields_for_placeholder(self) -> None:
        """The returned credentials must have id, provider, type, and title so
        manager.py can synthesise a valid CredentialsMetaInput placeholder."""
        from backend.integrations.providers import ProviderName

        creds = get_dry_run_credentials({"_dry_run_api_key": "sk-or-test"})
        assert creds is not None
        assert creds.id == "dry-run-platform"
        assert creds.provider == ProviderName.OPEN_ROUTER
        assert creds.type == "api_key"
        assert creds.title is not None


# ---------------------------------------------------------------------------
# simulate_block – input/output passthrough
# ---------------------------------------------------------------------------

@@ -1,10 +1,13 @@
from datetime import datetime, timezone
from typing import cast

import pytest
from pytest_mock import MockerFixture

from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
from backend.data.execution import ExecutionStatus
from backend.data.execution import ExecutionStatus, GraphExecutionWithNodes
from backend.data.model import User
from backend.executor.utils import add_graph_execution
from backend.util.mock import MockObject
@@ -473,6 +476,104 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
    assert result2 == mock_graph_exec_2


# ============================================================================
# Regression test: RPC layer returns typed User model, not raw dict
# ============================================================================


@pytest.mark.asyncio
async def test_add_graph_execution_via_rpc_returns_typed_user(
    mocker: MockerFixture,
):
    """
    Regression test: `add_graph_execution` accesses `user.timezone` on the User
    returned by `get_user_by_id`. This test verifies the downstream code path
    completes without AttributeError when `get_user_by_id` returns a proper typed
    User model. Note: the mock returns a User directly — _get_return deserialization
    is not exercised here; see TestGetReturn in util/service_test.py for that.
    """
    graph_id = "test-graph-id"
    user_id = "test-user-id"

    mock_graph = mocker.MagicMock()
    mock_graph.version = 1

    mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
    mock_graph_exec.id = "exec-id-rpc"
    mock_graph_exec.node_executions = []
    mock_graph_exec.status = ExecutionStatus.QUEUED
    mock_graph_exec.graph_version = 1
    mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()

    mock_queue = mocker.AsyncMock()
    mock_event_bus = mocker.MagicMock()
    mock_event_bus.publish = mocker.AsyncMock()

    mock_validate = mocker.patch(
        "backend.executor.utils.validate_and_construct_node_execution_input"
    )
    mock_validate.return_value = (mock_graph, [], {}, set())

    mock_prisma = mocker.patch("backend.executor.utils.prisma")
    mock_prisma.is_connected.return_value = (
        False  # prisma not connected: uses RPC path instead
    )

    # The RPC layer (_get_return) deserializes JSON dicts into typed Pydantic models.
    # The mock simulates what add_graph_execution receives after that deserialization:
    # a proper User model, not a raw dict.
    mock_user = User(
        id=user_id,
        email="test@example.com",
        name=None,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
        stripe_customer_id=None,
        top_up_config=None,
        timezone="UTC",
    )

    mock_db_client = mocker.MagicMock()
    mock_db_client.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
    mock_db_client.get_graph_settings = mocker.AsyncMock(
        return_value=mocker.MagicMock(
            human_in_the_loop_safe_mode=False, sensitive_action_safe_mode=False
        )
    )
    mock_db_client.create_graph_execution = mocker.AsyncMock(
        return_value=mock_graph_exec
    )
    mock_db_client.update_graph_execution_stats = mocker.AsyncMock(
        return_value=mock_graph_exec
    )
    mock_db_client.update_node_execution_status_batch = mocker.AsyncMock()
    mock_workspace = mocker.MagicMock()
    mock_workspace.id = "ws-id"
    mock_db_client.get_or_create_workspace = mocker.AsyncMock(
        return_value=mock_workspace
    )
    mock_db_client.increment_onboarding_runs = mocker.AsyncMock()

    mocker.patch(
        "backend.executor.utils.get_database_manager_async_client",
        return_value=mock_db_client,
    )
    mocker.patch(
        "backend.executor.utils.get_async_execution_queue", return_value=mock_queue
    )
    mocker.patch(
        "backend.executor.utils.get_async_execution_event_bus",
        return_value=mock_event_bus,
    )

    # Must not raise AttributeError: 'dict' object has no attribute 'timezone'
    result = await add_graph_execution(
        graph_id=graph_id,
        user_id=user_id,
    )
    assert result == mock_graph_exec


# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================

@@ -26,6 +26,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import httpx
|
||||
import sentry_sdk
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request, responses
|
||||
from prisma.errors import DataError, UniqueViolationError
|
||||
@@ -711,16 +712,16 @@ def get_service_client(
|
||||
def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
|
||||
"""Validate and coerce the RPC result to the expected return type.
|
||||
|
||||
Falls back to the raw result with a warning if validation fails.
|
||||
Falls back to the raw result with a warning and Sentry capture if validation fails.
|
||||
"""
|
||||
if expected_return:
|
||||
try:
|
||||
return expected_return.validate_python(result)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"RPC return type validation failed, using raw result: %s",
|
||||
type(e).__name__,
|
||||
f"RPC return type validation failed for {type(e).__name__}: {e}"
|
||||
)
|
||||
sentry_sdk.capture_exception(e)
|
||||
return result
|
||||
return result


@@ -1,13 +1,17 @@
import asyncio
import contextlib
import time
from datetime import datetime, timezone
from functools import cached_property
from typing import Any, Protocol, cast
from unittest.mock import Mock

import httpx
import pytest
from prisma.errors import DataError, UniqueViolationError
from pydantic import TypeAdapter

from backend.data.model import User
from backend.util.service import (
    AppService,
    AppServiceClient,
@@ -21,6 +25,10 @@ from backend.util.service import (
TEST_SERVICE_PORT = 8765


class _SupportsGetReturn(Protocol):
    def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any: ...


class ServiceTest(AppService):
    def __init__(self):
        super().__init__()
@@ -688,3 +696,46 @@ async def test_health_check_during_shutdown(test_service):
    except (httpx.ConnectError, httpx.ConnectTimeout):
        # Connection refused/timeout is also acceptable
        pass


# ============================================================================
# Unit tests for DynamicClient._get_return
# ============================================================================


class TestGetReturn:
    """Direct unit tests for DynamicClient._get_return typed-return contract."""

    def _make_client(self) -> _SupportsGetReturn:
        return cast(_SupportsGetReturn, get_service_client(ServiceTestClient))

    def test_valid_dict_is_deserialized_to_user_model(self):
        """TypeAdapter(User) + valid dict → User model returned with .timezone accessible.

        User.model_config uses extra='ignore' so unknown fields (e.g. new columns added
        in a newer database-manager deploy) are silently dropped instead of raising
        ValidationError — making the RPC layer forward-compatible during rolling deploys.
        """
        now = datetime.now(timezone.utc).isoformat()
        valid_dict = {
            "id": "user-id",
            "email": "test@example.com",
            "created_at": now,
            "updated_at": now,
            "unknown_future_field": "some_value",  # simulates a new DB field during deploy
        }
        client = self._make_client()
        adapter = TypeAdapter(User)
        result = client._get_return(adapter, valid_dict)

        assert isinstance(result, User)
        assert result.timezone is not None

    def test_invalid_dict_falls_back_to_raw_result(self):
        """TypeAdapter(User) + invalid dict (missing required fields) → fallback returns raw dict."""
        invalid_dict = {"id": "user-id"}  # missing email, created_at, updated_at
        client = self._make_client()
        adapter = TypeAdapter(User)

        result = client._get_return(adapter, invalid_dict)
        assert result == invalid_dict
@@ -155,6 +155,7 @@ class WorkspaceManager:
        path: Optional[str] = None,
        mime_type: Optional[str] = None,
        overwrite: bool = False,
        metadata: Optional[dict] = None,
    ) -> WorkspaceFile:
        """
        Write file to workspace.
@@ -168,6 +169,7 @@ class WorkspaceManager:
            path: Virtual path (defaults to "/(unknown)", session-scoped if session_id set)
            mime_type: MIME type (auto-detected if not provided)
            overwrite: Whether to overwrite existing file at path
            metadata: Optional metadata dict (e.g., origin tracking)

        Returns:
            Created WorkspaceFile instance
@@ -246,6 +248,7 @@ class WorkspaceManager:
                mime_type=mime_type,
                size_bytes=len(content),
                checksum=checksum,
                metadata=metadata,
            )
        except UniqueViolationError:
            if retries > 0:

@@ -0,0 +1,43 @@
-- CreateTable
CREATE TABLE "PlatformCostLog" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "userId" TEXT,
    "graphExecId" TEXT,
    "nodeExecId" TEXT,
    "graphId" TEXT,
    "nodeId" TEXT,
    "blockId" TEXT,
    "blockName" TEXT,
    "provider" TEXT NOT NULL,
    "credentialId" TEXT,
    "costMicrodollars" BIGINT,
    "inputTokens" INTEGER,
    "outputTokens" INTEGER,
    "dataSize" INTEGER,
    "duration" DOUBLE PRECISION,
    "model" TEXT,
    "trackingType" TEXT,
    "trackingAmount" DOUBLE PRECISION,
    "metadata" JSONB,

    CONSTRAINT "PlatformCostLog_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE INDEX "PlatformCostLog_userId_createdAt_idx" ON "PlatformCostLog"("userId", "createdAt");

-- CreateIndex
CREATE INDEX "PlatformCostLog_provider_createdAt_idx" ON "PlatformCostLog"("provider", "createdAt");

-- CreateIndex
CREATE INDEX "PlatformCostLog_createdAt_idx" ON "PlatformCostLog"("createdAt");

-- CreateIndex
CREATE INDEX "PlatformCostLog_graphExecId_idx" ON "PlatformCostLog"("graphExecId");

-- CreateIndex
CREATE INDEX "PlatformCostLog_provider_trackingType_idx" ON "PlatformCostLog"("provider", "trackingType");

-- AddForeignKey
ALTER TABLE "PlatformCostLog" ADD CONSTRAINT "PlatformCostLog_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;

@@ -75,6 +75,8 @@ model User {
  PendingHumanReviews PendingHumanReview[]
  Workspace           UserWorkspace?

  PlatformCostLogs PlatformCostLog[]

  // OAuth Provider relations
  OAuthApplications       OAuthApplication[]
  OAuthAuthorizationCodes OAuthAuthorizationCode[]
@@ -815,6 +817,45 @@ model CreditRefundRequest {
  @@index([userId, transactionKey])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////////    Platform Cost Tracking TABLES   //////////////
////////////////////////////////////////////////////////////

model PlatformCostLog {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())

  userId       String?
  User         User?   @relation(fields: [userId], references: [id], onDelete: SetNull)
  graphExecId  String?
  nodeExecId   String?
  graphId      String?
  nodeId       String?
  blockId      String?
  blockName    String?
  provider     String
  credentialId String?

  // Cost in microdollars (1 USD = 1,000,000). Null if unknown.
  costMicrodollars BigInt?

  inputTokens    Int?
  outputTokens   Int?
  dataSize       Int? // bytes
  duration       Float? // seconds
  model          String?
  trackingType   String? // e.g. "cost_usd", "tokens", "characters", "items", "per_run", "sandbox_seconds", "walltime_seconds"
  trackingAmount Float? // Amount in the unit implied by trackingType
  metadata       Json?

  @@index([userId, createdAt])
  @@index([provider, createdAt])
  @@index([createdAt])
  @@index([graphExecId])
  @@index([provider, trackingType])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////////////    Store TABLES    //////////////////////////

@@ -66,6 +66,29 @@ describe("useOnboardingWizardStore", () => {
      "no tests",
    ]);
  });

  it("ignores new selections when at the max limit", () => {
    useOnboardingWizardStore.getState().togglePainPoint("a");
    useOnboardingWizardStore.getState().togglePainPoint("b");
    useOnboardingWizardStore.getState().togglePainPoint("c");
    useOnboardingWizardStore.getState().togglePainPoint("d");
    expect(useOnboardingWizardStore.getState().painPoints).toEqual([
      "a",
      "b",
      "c",
    ]);
  });

  it("still allows deselecting when at the max limit", () => {
    useOnboardingWizardStore.getState().togglePainPoint("a");
    useOnboardingWizardStore.getState().togglePainPoint("b");
    useOnboardingWizardStore.getState().togglePainPoint("c");
    useOnboardingWizardStore.getState().togglePainPoint("b");
    expect(useOnboardingWizardStore.getState().painPoints).toEqual([
      "a",
      "c",
    ]);
  });
});
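The capped-toggle behaviour these two tests exercise is small enough to sketch on its own; the sketch below is in Python for consistency with the rest of this change, and the limit of three is an assumption read off the test expectations, not the store's actual constant:

```python
MAX_PAIN_POINTS = 3  # assumed limit, inferred from the test expectations


def toggle(selected: list[str], item: str) -> list[str]:
    # deselecting always works, even at the cap
    if item in selected:
        return [s for s in selected if s != item]
    # new selections are ignored once the cap is reached
    if len(selected) >= MAX_PAIN_POINTS:
        return selected
    return [*selected, item]


state: list[str] = []
for choice in ["a", "b", "c", "d"]:  # "d" is ignored at the limit
    state = toggle(state, choice)
print(state)               # ['a', 'b', 'c']
print(toggle(state, "b"))  # deselect still allowed: ['a', 'c']
```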
|
||||
|
||||
describe("setOtherPainPoint", () => {
|
||||
|
||||
@@ -7,9 +7,9 @@ export function ProgressBar({ currentStep, totalSteps }: Props) {
  const percent = (currentStep / totalSteps) * 100;

  return (
    <div className="absolute left-0 top-0 h-[0.625rem] w-full bg-neutral-300">
    <div className="absolute left-0 top-0 h-[3px] w-full bg-neutral-200">
      <div
        className="h-full bg-purple-400 shadow-[0_0_4px_2px_rgba(168,85,247,0.5)] transition-all duration-500 ease-out"
        className="h-full bg-purple-400 transition-all duration-500 ease-out"
        style={{ width: `${percent}%` }}
      />
    </div>
@@ -2,6 +2,7 @@

import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { Check } from "@phosphor-icons/react";

interface Props {
  icon: React.ReactNode;
@@ -24,13 +25,18 @@ export function SelectableCard({
      onClick={onClick}
      aria-pressed={selected}
      className={cn(
        "flex h-[9rem] w-[10.375rem] shrink-0 flex-col items-center justify-center gap-3 rounded-xl border-2 bg-white px-6 py-5 transition-all hover:shadow-sm md:shrink lg:gap-2 lg:px-10 lg:py-8",
        "relative flex h-[9rem] w-[10.375rem] shrink-0 flex-col items-center justify-center gap-3 rounded-xl border-2 bg-white px-6 py-5 transition-all hover:shadow-sm md:shrink lg:gap-2 lg:px-10 lg:py-8",
        className,
        selected
          ? "border-purple-500 bg-purple-50 shadow-sm"
          : "border-transparent",
      )}
    >
      {selected && (
        <span className="absolute right-2 top-2 flex h-5 w-5 items-center justify-center rounded-full bg-purple-500">
          <Check size={12} weight="bold" className="text-white" />
        </span>
      )}
      <Text
        variant="lead"
        as="span"
@@ -3,6 +3,7 @@
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { ReactNode } from "react";

import { FadeIn } from "@/components/atoms/FadeIn/FadeIn";
@@ -73,6 +74,8 @@ export function PainPointsStep() {
    togglePainPoint,
    setOtherPainPoint,
    hasSomethingElse,
    atLimit,
    shaking,
    canContinue,
    handleLaunch,
  } = usePainPointsStep();
@@ -90,7 +93,7 @@ export function PainPointsStep() {
          What's eating your time?
        </Text>
        <Text variant="lead" className="!text-zinc-500">
          Pick the tasks you'd love to hand off to Autopilot
          Pick the tasks you'd love to hand off to AutoPilot
        </Text>
      </div>

@@ -107,11 +110,22 @@ export function PainPointsStep() {
            />
          ))}
        </div>
        {!hasSomethingElse ? (
          <Text variant="small" className="!text-zinc-500">
            Pick as many as you want — you can always change later
          </Text>
        ) : null}
        <Text
          variant="small"
          className={cn(
            "transition-colors",
            atLimit && canContinue ? "!text-green-600" : "!text-zinc-500",
            shaking && "animate-shake",
          )}
        >
          {shaking
            ? "You've picked 3 — tap one to swap it out"
            : atLimit && canContinue
              ? "3 selected — you're all set!"
              : atLimit && hasSomethingElse
                ? "Tell us what else takes up your time"
                : "Pick up to 3 to start — AutoPilot can help with anything else later"}
        </Text>
      </div>

      {hasSomethingElse && (
@@ -133,7 +147,7 @@ export function PainPointsStep() {
          disabled={!canContinue}
          className="w-full max-w-xs"
        >
          Launch Autopilot
          Launch AutoPilot
        </Button>
      </div>
    </FadeIn>
@@ -8,6 +8,7 @@ import { FadeIn } from "@/components/atoms/FadeIn/FadeIn";
import { SelectableCard } from "../components/SelectableCard";
import { useOnboardingWizardStore } from "../store";
import { Emoji } from "@/components/atoms/Emoji/Emoji";
import { useEffect, useRef } from "react";

const IMG_SIZE = 42;

@@ -57,12 +58,26 @@ export function RoleStep() {
  const setRole = useOnboardingWizardStore((s) => s.setRole);
  const setOtherRole = useOnboardingWizardStore((s) => s.setOtherRole);
  const nextStep = useOnboardingWizardStore((s) => s.nextStep);
  const autoAdvanceTimer = useRef<ReturnType<typeof setTimeout> | null>(null);

  const isOther = role === "Other";
  const canContinue = role && (!isOther || otherRole.trim());

  function handleContinue() {
    if (canContinue) {
  useEffect(() => {
    return () => {
      if (autoAdvanceTimer.current) clearTimeout(autoAdvanceTimer.current);
    };
  }, []);

  function handleRoleSelect(id: string) {
    if (autoAdvanceTimer.current) clearTimeout(autoAdvanceTimer.current);
    setRole(id);
    if (id !== "Other") {
      autoAdvanceTimer.current = setTimeout(nextStep, 350);
    }
  }

  function handleOtherContinue() {
    if (otherRole.trim()) {
      nextStep();
    }
  }
@@ -78,7 +93,7 @@ export function RoleStep() {
          What best describes you, {name}?
        </Text>
        <Text variant="lead" className="!text-zinc-500">
          Autopilot will tailor automations to your world
          So AutoPilot knows how to help you best
        </Text>
      </div>

@@ -89,33 +104,35 @@ export function RoleStep() {
            icon={r.icon}
            label={r.label}
            selected={role === r.id}
            onClick={() => setRole(r.id)}
            onClick={() => handleRoleSelect(r.id)}
            className="p-8"
          />
        ))}
      </div>

      {isOther && (
        <div className="-mb-5 w-full px-8 md:px-0">
          <Input
            id="other-role"
            label="Other role"
            hideLabel
            placeholder="Describe your role..."
            value={otherRole}
            onChange={(e) => setOtherRole(e.target.value)}
            autoFocus
          />
        </div>
      )}
        <>
          <div className="-mb-5 w-full px-8 md:px-0">
            <Input
              id="other-role"
              label="Other role"
              hideLabel
              placeholder="Describe your role..."
              value={otherRole}
              onChange={(e) => setOtherRole(e.target.value)}
              autoFocus
            />
          </div>

          <Button
            onClick={handleContinue}
            disabled={!canContinue}
            className="w-full max-w-xs"
          >
            Continue
          </Button>
          <Button
            onClick={handleOtherContinue}
            disabled={!otherRole.trim()}
            className="w-full max-w-xs"
          >
            Continue
          </Button>
        </>
      )}
    </div>
  </FadeIn>
);
@@ -4,13 +4,6 @@ import { AutoGPTLogo } from "@/components/atoms/AutoGPTLogo/AutoGPTLogo";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import { Text } from "@/components/atoms/Text/Text";
import {
  Tooltip,
  TooltipContent,
  TooltipProvider,
  TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { Question } from "@phosphor-icons/react";
import { FadeIn } from "@/components/atoms/FadeIn/FadeIn";
import { useOnboardingWizardStore } from "../store";

@@ -40,36 +33,16 @@ export function WelcomeStep() {
        <Text variant="h3">Welcome to AutoGPT</Text>
        <Text variant="lead" as="span" className="!text-zinc-500">
          Let's personalize your experience so{" "}
          <span className="relative mr-3 inline-block bg-gradient-to-r from-purple-500 to-indigo-500 bg-clip-text text-transparent">
            Autopilot
            <span className="absolute -right-4 top-0">
              <TooltipProvider delayDuration={400}>
                <Tooltip>
                  <TooltipTrigger asChild>
                    <button
                      type="button"
                      aria-label="What is Autopilot?"
                      className="inline-flex text-purple-500"
                    >
                      <Question size={14} />
                    </button>
                  </TooltipTrigger>
                  <TooltipContent>
                    Autopilot is AutoGPT's AI assistant that watches your
                    connected apps, spots repetitive tasks you do every day
                    and runs them for you automatically.
                  </TooltipContent>
                </Tooltip>
              </TooltipProvider>
            </span>
          <span className="bg-gradient-to-r from-purple-500 to-indigo-500 bg-clip-text text-transparent">
            AutoPilot
          </span>{" "}
          can start saving you time right away
          can start saving you time
        </Text>
      </div>

      <Input
        id="first-name"
        label="Your first name"
        label="What should I call you?"
        placeholder="e.g. John"
        value={name}
        onChange={(e) => setName(e.target.value)}
@@ -0,0 +1,154 @@
import {
  render,
  screen,
  fireEvent,
  cleanup,
} from "@/tests/integrations/test-utils";
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { useOnboardingWizardStore } from "../../store";
import { PainPointsStep } from "../PainPointsStep";

vi.mock("@/components/atoms/Emoji/Emoji", () => ({
  Emoji: ({ text }: { text: string }) => <span>{text}</span>,
}));

vi.mock("@/components/atoms/FadeIn/FadeIn", () => ({
  FadeIn: ({ children }: { children: React.ReactNode }) => (
    <div>{children}</div>
  ),
}));

function getCard(name: RegExp) {
  return screen.getByRole("button", { name });
}

function clickCard(name: RegExp) {
  fireEvent.click(getCard(name));
}

function getLaunchButton() {
  return screen.getByRole("button", { name: /launch autopilot/i });
}

afterEach(cleanup);

beforeEach(() => {
  useOnboardingWizardStore.getState().reset();
  useOnboardingWizardStore.getState().setName("Alice");
  useOnboardingWizardStore.getState().setRole("Founder/CEO");
  useOnboardingWizardStore.getState().goToStep(3);
});

describe("PainPointsStep", () => {
  test("renders all pain point cards", () => {
    render(<PainPointsStep />);

    expect(getCard(/finding leads/i)).toBeDefined();
    expect(getCard(/email & outreach/i)).toBeDefined();
    expect(getCard(/reports & data/i)).toBeDefined();
    expect(getCard(/customer support/i)).toBeDefined();
    expect(getCard(/social media/i)).toBeDefined();
    expect(getCard(/something else/i)).toBeDefined();
  });

  test("shows default helper text", () => {
    render(<PainPointsStep />);

    expect(
      screen.getAllByText(/pick up to 3 to start/i).length,
    ).toBeGreaterThan(0);
  });

  test("selecting a card marks it as pressed", () => {
    render(<PainPointsStep />);

    clickCard(/finding leads/i);

    expect(getCard(/finding leads/i).getAttribute("aria-pressed")).toBe("true");
  });

  test("launch button is disabled when nothing is selected", () => {
    render(<PainPointsStep />);

    expect(getLaunchButton().hasAttribute("disabled")).toBe(true);
  });

  test("launch button is enabled after selecting a pain point", () => {
    render(<PainPointsStep />);

    clickCard(/finding leads/i);

    expect(getLaunchButton().hasAttribute("disabled")).toBe(false);
  });

  test("shows success text when 3 items are selected", () => {
    render(<PainPointsStep />);

    clickCard(/finding leads/i);
    clickCard(/email & outreach/i);
    clickCard(/reports & data/i);

    expect(screen.getAllByText(/3 selected/i).length).toBeGreaterThan(0);
  });

  test("does not select a 4th item when at the limit", () => {
    render(<PainPointsStep />);

    clickCard(/finding leads/i);
    clickCard(/email & outreach/i);
    clickCard(/reports & data/i);
    clickCard(/customer support/i);

    expect(getCard(/customer support/i).getAttribute("aria-pressed")).toBe(
      "false",
    );
  });

  test("can deselect when at the limit and select a different one", () => {
    render(<PainPointsStep />);

    clickCard(/finding leads/i);
    clickCard(/email & outreach/i);
    clickCard(/reports & data/i);

    clickCard(/finding leads/i);
    expect(getCard(/finding leads/i).getAttribute("aria-pressed")).toBe(
      "false",
    );

    clickCard(/customer support/i);
    expect(getCard(/customer support/i).getAttribute("aria-pressed")).toBe(
      "true",
    );
  });

  test("shows input when 'Something else' is selected", () => {
    render(<PainPointsStep />);

    clickCard(/something else/i);

    expect(
      screen.getByPlaceholderText(/what else takes up your time/i),
    ).toBeDefined();
  });

  test("launch button is disabled when 'Something else' selected but input empty", () => {
    render(<PainPointsStep />);

    clickCard(/something else/i);

    expect(getLaunchButton().hasAttribute("disabled")).toBe(true);
  });

  test("launch button is enabled when 'Something else' selected and input filled", () => {
    render(<PainPointsStep />);

    clickCard(/something else/i);
    fireEvent.change(
      screen.getByPlaceholderText(/what else takes up your time/i),
      { target: { value: "Manual invoicing" } },
    );

    expect(getLaunchButton().hasAttribute("disabled")).toBe(false);
  });
});
@@ -0,0 +1,123 @@
import {
  render,
  screen,
  fireEvent,
  cleanup,
} from "@/tests/integrations/test-utils";
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { useOnboardingWizardStore } from "../../store";
import { RoleStep } from "../RoleStep";

vi.mock("@/components/atoms/Emoji/Emoji", () => ({
  Emoji: ({ text }: { text: string }) => <span>{text}</span>,
}));

vi.mock("@/components/atoms/FadeIn/FadeIn", () => ({
  FadeIn: ({ children }: { children: React.ReactNode }) => (
    <div>{children}</div>
  ),
}));

afterEach(() => {
  cleanup();
  vi.useRealTimers();
});

beforeEach(() => {
  vi.useFakeTimers();
  useOnboardingWizardStore.getState().reset();
  useOnboardingWizardStore.getState().setName("Alice");
  useOnboardingWizardStore.getState().goToStep(2);
});

describe("RoleStep", () => {
  test("renders all role cards", () => {
    render(<RoleStep />);

    expect(screen.getByText("Founder / CEO")).toBeDefined();
    expect(screen.getByText("Operations")).toBeDefined();
    expect(screen.getByText("Sales / BD")).toBeDefined();
    expect(screen.getByText("Marketing")).toBeDefined();
    expect(screen.getByText("Product / PM")).toBeDefined();
    expect(screen.getByText("Engineering")).toBeDefined();
    expect(screen.getByText("HR / People")).toBeDefined();
    expect(screen.getByText("Other")).toBeDefined();
  });

  test("displays the user name in the heading", () => {
    render(<RoleStep />);

    expect(
      screen.getAllByText(/what best describes you, alice/i).length,
    ).toBeGreaterThan(0);
  });

  test("selecting a non-Other role auto-advances after delay", () => {
    render(<RoleStep />);

    fireEvent.click(screen.getByRole("button", { name: /engineering/i }));

    expect(useOnboardingWizardStore.getState().role).toBe("Engineering");
    expect(useOnboardingWizardStore.getState().currentStep).toBe(2);

    vi.advanceTimersByTime(350);

    expect(useOnboardingWizardStore.getState().currentStep).toBe(3);
  });

  test("selecting 'Other' does not auto-advance", () => {
    render(<RoleStep />);

    fireEvent.click(screen.getByRole("button", { name: /\bother\b/i }));

    vi.advanceTimersByTime(500);

    expect(useOnboardingWizardStore.getState().currentStep).toBe(2);
  });

  test("selecting 'Other' shows text input and Continue button", () => {
    render(<RoleStep />);

    fireEvent.click(screen.getByRole("button", { name: /\bother\b/i }));

    expect(screen.getByPlaceholderText(/describe your role/i)).toBeDefined();
    expect(screen.getByRole("button", { name: /continue/i })).toBeDefined();
  });

  test("Continue button is disabled when Other input is empty", () => {
    render(<RoleStep />);

    fireEvent.click(screen.getByRole("button", { name: /\bother\b/i }));

    const continueBtn = screen.getByRole("button", { name: /continue/i });
    expect(continueBtn.hasAttribute("disabled")).toBe(true);
  });

  test("Continue button advances when Other role text is filled", () => {
    render(<RoleStep />);

    fireEvent.click(screen.getByRole("button", { name: /\bother\b/i }));
    fireEvent.change(screen.getByPlaceholderText(/describe your role/i), {
      target: { value: "Designer" },
    });

    const continueBtn = screen.getByRole("button", { name: /continue/i });
    expect(continueBtn.hasAttribute("disabled")).toBe(false);

    fireEvent.click(continueBtn);
    expect(useOnboardingWizardStore.getState().currentStep).toBe(3);
  });

  test("switching from Other to a regular role cancels Other and auto-advances", () => {
    render(<RoleStep />);

    fireEvent.click(screen.getByRole("button", { name: /\bother\b/i }));
    expect(screen.getByPlaceholderText(/describe your role/i)).toBeDefined();

    fireEvent.click(screen.getByRole("button", { name: /marketing/i }));

    expect(useOnboardingWizardStore.getState().role).toBe("Marketing");
    vi.advanceTimersByTime(350);
    expect(useOnboardingWizardStore.getState().currentStep).toBe(3);
  });
});
@@ -1,4 +1,5 @@
import { useOnboardingWizardStore } from "../store";
import { useEffect, useRef, useState } from "react";
import { MAX_PAIN_POINT_SELECTIONS, useOnboardingWizardStore } from "../store";

const ROLE_TOP_PICKS: Record<string, string[]> = {
  "Founder/CEO": [
@@ -23,18 +24,38 @@ export function usePainPointsStep() {
  const role = useOnboardingWizardStore((s) => s.role);
  const painPoints = useOnboardingWizardStore((s) => s.painPoints);
  const otherPainPoint = useOnboardingWizardStore((s) => s.otherPainPoint);
  const togglePainPoint = useOnboardingWizardStore((s) => s.togglePainPoint);
  const storeToggle = useOnboardingWizardStore((s) => s.togglePainPoint);
  const setOtherPainPoint = useOnboardingWizardStore(
    (s) => s.setOtherPainPoint,
  );
  const nextStep = useOnboardingWizardStore((s) => s.nextStep);
  const [shaking, setShaking] = useState(false);
  const shakeTimer = useRef<ReturnType<typeof setTimeout> | null>(null);

  useEffect(() => {
    return () => {
      if (shakeTimer.current) clearTimeout(shakeTimer.current);
    };
  }, []);

  const topIDs = getTopPickIDs(role);
  const hasSomethingElse = painPoints.includes("Something else");
  const atLimit = painPoints.length >= MAX_PAIN_POINT_SELECTIONS;
  const canContinue =
    painPoints.length > 0 &&
    (!hasSomethingElse || Boolean(otherPainPoint.trim()));

  function togglePainPoint(id: string) {
    const alreadySelected = painPoints.includes(id);
    if (!alreadySelected && atLimit) {
      if (shakeTimer.current) clearTimeout(shakeTimer.current);
      setShaking(true);
      shakeTimer.current = setTimeout(() => setShaking(false), 600);
      return;
    }
    storeToggle(id);
  }

  function handleLaunch() {
    if (canContinue) {
      nextStep();
@@ -48,6 +69,8 @@ export function usePainPointsStep() {
    togglePainPoint,
    setOtherPainPoint,
    hasSomethingElse,
    atLimit,
    shaking,
    canContinue,
    handleLaunch,
  };
@@ -1,5 +1,6 @@
import { create } from "zustand";

export const MAX_PAIN_POINT_SELECTIONS = 3;
export type Step = 1 | 2 | 3 | 4;

interface OnboardingWizardState {
@@ -40,6 +41,8 @@ export const useOnboardingWizardStore = create<OnboardingWizardState>(
    togglePainPoint(painPoint) {
      set((state) => {
        const exists = state.painPoints.includes(painPoint);
        if (!exists && state.painPoints.length >= MAX_PAIN_POINT_SELECTIONS)
          return state;
        return {
          painPoints: exists
            ? state.painPoints.filter((p) => p !== painPoint)
@@ -1,6 +1,12 @@
import { Sidebar } from "@/components/__legacy__/Sidebar";
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
import { Gauge } from "@phosphor-icons/react/dist/ssr";
import {
  Users,
  CurrencyDollar,
  MagnifyingGlass,
  Gauge,
  Receipt,
  FileText,
} from "@phosphor-icons/react/dist/ssr";

import { IconSliders } from "@/components/__legacy__/ui/icons";

@@ -15,18 +21,23 @@ const sidebarLinkGroups = [
      {
        text: "User Spending",
        href: "/admin/spending",
        icon: <DollarSign className="h-6 w-6" />,
        icon: <CurrencyDollar className="h-6 w-6" />,
      },
      {
        text: "User Impersonation",
        href: "/admin/impersonation",
        icon: <UserSearch className="h-6 w-6" />,
        icon: <MagnifyingGlass className="h-6 w-6" />,
      },
      {
        text: "Rate Limits",
        href: "/admin/rate-limits",
        icon: <Gauge className="h-6 w-6" />,
      },
      {
        text: "Platform Costs",
        href: "/admin/platform-costs",
        icon: <Receipt className="h-6 w-6" />,
      },
      {
        text: "Execution Analytics",
        href: "/admin/execution-analytics",
@@ -0,0 +1,429 @@
import {
  render,
  screen,
  cleanup,
  waitFor,
} from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { PlatformCostContent } from "../components/PlatformCostContent";
import type { PlatformCostDashboard } from "@/app/api/__generated__/models/platformCostDashboard";
import type { PlatformCostLogsResponse } from "@/app/api/__generated__/models/platformCostLogsResponse";

// Mock the generated Orval hooks so tests don't hit the network
const mockUseGetDashboard = vi.fn();
const mockUseGetLogs = vi.fn();

vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({
  useGetV2GetPlatformCostDashboard: (...args: unknown[]) =>
    mockUseGetDashboard(...args),
  useGetV2GetPlatformCostLogs: (...args: unknown[]) => mockUseGetLogs(...args),
}));

afterEach(() => {
  cleanup();
  mockUseGetDashboard.mockReset();
  mockUseGetLogs.mockReset();
});

const emptyDashboard: PlatformCostDashboard = {
  total_cost_microdollars: 0,
  total_requests: 0,
  total_users: 0,
  by_provider: [],
  by_user: [],
};

const emptyLogs: PlatformCostLogsResponse = {
  logs: [],
  pagination: {
    current_page: 1,
    page_size: 50,
    total_items: 0,
    total_pages: 0,
  },
};

const dashboardWithData: PlatformCostDashboard = {
  total_cost_microdollars: 5_000_000,
  total_requests: 100,
  total_users: 5,
  by_provider: [
    {
      provider: "openai",
      tracking_type: "tokens",
      total_cost_microdollars: 3_000_000,
      total_input_tokens: 50000,
      total_output_tokens: 20000,
      total_duration_seconds: 0,
      request_count: 60,
    },
    {
      provider: "google_maps",
      tracking_type: "per_run",
      total_cost_microdollars: 0,
      total_input_tokens: 0,
      total_output_tokens: 0,
      total_duration_seconds: 0,
      request_count: 40,
    },
  ],
  by_user: [
    {
      user_id: "user-1",
      email: "alice@example.com",
      total_cost_microdollars: 3_000_000,
      total_input_tokens: 50000,
      total_output_tokens: 20000,
      request_count: 60,
    },
  ],
};

const logsWithData: PlatformCostLogsResponse = {
  logs: [
    {
      id: "log-1",
      created_at: "2026-03-01T00:00:00Z" as unknown as Date,
      user_id: "user-1",
      email: "alice@example.com",
      graph_exec_id: "gx-123",
      node_exec_id: "nx-456",
      block_name: "LLMBlock",
      provider: "openai",
      tracking_type: "tokens",
      cost_microdollars: 5000,
      input_tokens: 100,
      output_tokens: 50,
      duration: 1.5,
      model: "gpt-4",
    },
  ],
  pagination: {
    current_page: 1,
    page_size: 50,
    total_items: 1,
    total_pages: 1,
  },
};

function renderComponent(searchParams = {}) {
  return render(<PlatformCostContent searchParams={searchParams} />);
}

describe("PlatformCostContent", () => {
  it("shows loading state initially", () => {
    mockUseGetDashboard.mockReturnValue({ data: undefined, isLoading: true });
    mockUseGetLogs.mockReturnValue({ data: undefined, isLoading: true });
    renderComponent();
    // Loading state renders Skeleton placeholders (animate-pulse divs) instead of content
    expect(screen.queryByText("Loading...")).toBeNull();
    // Summary cards and table content are not yet shown
    expect(screen.queryByText("Known Cost")).toBeNull();
  });

  it("renders empty dashboard", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: emptyDashboard,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: emptyLogs,
      isLoading: false,
    });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    // Verify the two summary cards that show $0.0000 — Known Cost and Estimated Total
    const zeroCostItems = screen.getAllByText("$0.0000");
    expect(zeroCostItems.length).toBe(2);
    expect(screen.getByText("No cost data yet")).toBeDefined();
  });

  it("renders dashboard with provider data", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("$5.0000")).toBeDefined();
    expect(screen.getByText("100")).toBeDefined();
    expect(screen.getByText("5")).toBeDefined();
    expect(screen.getByText("openai")).toBeDefined();
    expect(screen.getByText("google_maps")).toBeDefined();
  });

  it("renders tracking type badges", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("tokens")).toBeDefined();
    expect(screen.getByText("per_run")).toBeDefined();
  });

  it("shows error state on fetch failure", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: undefined,
      isLoading: false,
      error: new Error("Network error"),
    });
    mockUseGetLogs.mockReturnValue({
      data: undefined,
      isLoading: false,
      error: new Error("Network error"),
    });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("Network error")).toBeDefined();
  });

  it("renders tab buttons", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: emptyDashboard,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("By Provider")).toBeDefined();
    expect(screen.getByText("By User")).toBeDefined();
    expect(screen.getByText("Raw Logs")).toBeDefined();
  });

  it("renders summary cards with correct labels", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
    expect(screen.getByText("Estimated Total")).toBeDefined();
    expect(screen.getByText("Total Requests")).toBeDefined();
    expect(screen.getByText("Active Users")).toBeDefined();
  });

  it("renders filter inputs", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: emptyDashboard,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
    renderComponent();
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("Start Date")).toBeDefined();
    expect(screen.getByText("End Date")).toBeDefined();
    expect(screen.getAllByText(/Provider/i).length).toBeGreaterThanOrEqual(1);
    expect(screen.getByText("User ID")).toBeDefined();
    expect(screen.getByText("Apply")).toBeDefined();
  });

  it("renders by-user tab when specified", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent({ tab: "by-user" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("alice@example.com")).toBeDefined();
  });

  it("renders logs tab when specified", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent({ tab: "logs" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("LLMBlock")).toBeDefined();
    expect(screen.getByText("gpt-4")).toBeDefined();
  });

  it("renders no logs message when empty", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: emptyDashboard,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
    renderComponent({ tab: "logs" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("No logs found")).toBeDefined();
  });

  it("shows pagination when multiple pages", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    const multiPageLogs: PlatformCostLogsResponse = {
      logs: logsWithData.logs,
      pagination: {
        current_page: 1,
        page_size: 50,
        total_items: 200,
        total_pages: 4,
      },
    };
    mockUseGetLogs.mockReturnValue({
      data: multiPageLogs,
      isLoading: false,
    });
    renderComponent({ tab: "logs" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("Previous")).toBeDefined();
    expect(screen.getByText("Next")).toBeDefined();
    expect(screen.getByText(/Page 1 of 4/)).toBeDefined();
  });

  it("renders user table with unknown email", async () => {
    const dashWithNullEmail: PlatformCostDashboard = {
      ...dashboardWithData,
      by_user: [
        {
          user_id: "user-2",
          email: null,
          total_cost_microdollars: 1000,
          total_input_tokens: 100,
          total_output_tokens: 50,
          request_count: 5,
        },
      ],
    };
    mockUseGetDashboard.mockReturnValue({
      data: dashWithNullEmail,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
    renderComponent({ tab: "by-user" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
    );
    expect(screen.getByText("Unknown")).toBeDefined();
  });

  it("by-user tab content visible when tab=by-user param set", async () => {
    mockUseGetDashboard.mockReturnValue({
      data: dashboardWithData,
      isLoading: false,
    });
    mockUseGetLogs.mockReturnValue({
      data: logsWithData,
      isLoading: false,
    });
    renderComponent({ tab: "by-user" });
    await waitFor(() =>
      expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("alice@example.com")).toBeDefined();
|
||||
// overview tab content should not be visible
|
||||
expect(screen.queryByText("openai")).toBeNull();
|
||||
});
|
||||
|
||||
it("logs tab content visible when tab=logs param set", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent({ tab: "logs" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("LLMBlock")).toBeDefined();
|
||||
expect(screen.getByText("gpt-4")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders log with null user as dash", async () => {
|
||||
const logWithNullUser: PlatformCostLogsResponse = {
|
||||
logs: [
|
||||
{
|
||||
id: "log-2",
|
||||
created_at: "2026-03-01T00:00:00Z" as unknown as Date,
|
||||
user_id: null,
|
||||
email: null,
|
||||
graph_exec_id: null,
|
||||
node_exec_id: null,
|
||||
block_name: "copilot:SDK",
|
||||
provider: "anthropic",
|
||||
tracking_type: "cost_usd",
|
||||
cost_microdollars: 15000,
|
||||
input_tokens: null,
|
||||
output_tokens: null,
|
||||
duration: null,
|
||||
model: "claude-opus-4-20250514",
|
||||
},
|
||||
],
|
||||
pagination: {
|
||||
current_page: 1,
|
||||
page_size: 50,
|
||||
total_items: 1,
|
||||
total_pages: 1,
|
||||
},
|
||||
};
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: emptyDashboard,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logWithNullUser,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent({ tab: "logs" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("copilot:SDK")).toBeDefined();
|
||||
expect(screen.getByText("anthropic")).toBeDefined();
|
||||
// null email + null user_id renders as "-" in the User column; multiple
|
||||
// other cells (tokens, duration, session) also render "-", so use
|
||||
// getAllByText to avoid the single-match constraint.
|
||||
expect(screen.getAllByText("-").length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,87 @@
import { describe, expect, it, vi } from "vitest";

const mockGetDashboard = vi.fn();
const mockGetLogs = vi.fn();

vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({
  getV2GetPlatformCostDashboard: (...args: unknown[]) =>
    mockGetDashboard(...args),
  getV2GetPlatformCostLogs: (...args: unknown[]) => mockGetLogs(...args),
}));

import { getPlatformCostDashboard, getPlatformCostLogs } from "../actions";

describe("getPlatformCostDashboard", () => {
  it("returns data on success", async () => {
    const mockData = { total_cost_microdollars: 1000, total_requests: 5 };
    mockGetDashboard.mockResolvedValue({ status: 200, data: mockData });
    const result = await getPlatformCostDashboard();
    expect(result).toEqual(mockData);
  });

  it("returns undefined on non-200", async () => {
    mockGetDashboard.mockResolvedValue({ status: 401 });
    const result = await getPlatformCostDashboard();
    expect(result).toBeUndefined();
  });

  it("passes filter params to API", async () => {
    mockGetDashboard.mockReset();
    mockGetDashboard.mockResolvedValue({ status: 200, data: {} });
    await getPlatformCostDashboard({
      start: "2026-01-01T00:00:00",
      end: "2026-06-01T00:00:00",
      provider: "openai",
      user_id: "user-1",
    });
    expect(mockGetDashboard).toHaveBeenCalledTimes(1);
    const params = mockGetDashboard.mock.calls[0][0];
    expect(params.start).toBe("2026-01-01T00:00:00");
    expect(params.end).toBe("2026-06-01T00:00:00");
    expect(params.provider).toBe("openai");
    expect(params.user_id).toBe("user-1");
  });

  it("passes undefined for empty filter strings", async () => {
    mockGetDashboard.mockReset();
    mockGetDashboard.mockResolvedValue({ status: 200, data: {} });
    await getPlatformCostDashboard({
      start: "",
      provider: "",
      user_id: "",
    });
    expect(mockGetDashboard).toHaveBeenCalledTimes(1);
    const params = mockGetDashboard.mock.calls[0][0];
    expect(params.start).toBeUndefined();
    expect(params.provider).toBeUndefined();
    expect(params.user_id).toBeUndefined();
  });
});

describe("getPlatformCostLogs", () => {
  it("returns data on success", async () => {
    const mockData = { logs: [], pagination: { current_page: 1 } };
    mockGetLogs.mockResolvedValue({ status: 200, data: mockData });
    const result = await getPlatformCostLogs();
    expect(result).toEqual(mockData);
  });

  it("passes page and page_size", async () => {
    mockGetLogs.mockReset();
    mockGetLogs.mockResolvedValue({ status: 200, data: { logs: [] } });
    await getPlatformCostLogs({ page: 3, page_size: 25 });
    expect(mockGetLogs).toHaveBeenCalledTimes(1);
    const params = mockGetLogs.mock.calls[0][0];
    expect(params.page).toBe(3);
    expect(params.page_size).toBe(25);
  });

  it("passes start date string through to API", async () => {
    mockGetLogs.mockReset();
    mockGetLogs.mockResolvedValue({ status: 200, data: { logs: [] } });
    await getPlatformCostLogs({ start: "2026-03-01T00:00:00" });
    expect(mockGetLogs).toHaveBeenCalledTimes(1);
    const params = mockGetLogs.mock.calls[0][0];
    expect(params.start).toBe("2026-03-01T00:00:00");
  });
});
@@ -0,0 +1,300 @@
import { describe, expect, it } from "vitest";
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
import {
  toDateOrUndefined,
  formatMicrodollars,
  formatTokens,
  formatDuration,
  estimateCostForRow,
  trackingValue,
  toLocalInput,
  toUtcIso,
} from "../helpers";

function makeRow(overrides: Partial<ProviderCostSummary>): ProviderCostSummary {
  return {
    provider: "openai",
    tracking_type: null,
    total_cost_microdollars: 0,
    total_input_tokens: 0,
    total_output_tokens: 0,
    total_duration_seconds: 0,
    request_count: 0,
    ...overrides,
  };
}

describe("toDateOrUndefined", () => {
  it("returns undefined for empty string", () => {
    expect(toDateOrUndefined("")).toBeUndefined();
  });

  it("returns undefined for undefined", () => {
    expect(toDateOrUndefined(undefined)).toBeUndefined();
  });

  it("returns undefined for invalid date string", () => {
    expect(toDateOrUndefined("not-a-date")).toBeUndefined();
  });

  it("returns a Date for a valid ISO string", () => {
    const result = toDateOrUndefined("2026-01-15T00:00:00Z");
    expect(result).toBeInstanceOf(Date);
    expect(result!.toISOString()).toBe("2026-01-15T00:00:00.000Z");
  });
});

describe("formatMicrodollars", () => {
  it("formats zero", () => {
    expect(formatMicrodollars(0)).toBe("$0.0000");
  });

  it("formats a small amount", () => {
    expect(formatMicrodollars(50_000)).toBe("$0.0500");
  });

  it("formats one dollar", () => {
    expect(formatMicrodollars(1_000_000)).toBe("$1.0000");
  });
});

describe("formatTokens", () => {
  it("formats small numbers as-is", () => {
    expect(formatTokens(500)).toBe("500");
  });

  it("formats thousands with K suffix", () => {
    expect(formatTokens(1_500)).toBe("1.5K");
  });

  it("formats millions with M suffix", () => {
    expect(formatTokens(2_500_000)).toBe("2.5M");
  });
});

describe("formatDuration", () => {
  it("formats seconds", () => {
    expect(formatDuration(30)).toBe("30.0s");
  });

  it("formats minutes", () => {
    expect(formatDuration(90)).toBe("1.5m");
  });

  it("formats hours", () => {
    expect(formatDuration(5400)).toBe("1.5h");
  });
});

describe("estimateCostForRow", () => {
  it("returns microdollars directly for cost_usd tracking", () => {
    const row = makeRow({
      tracking_type: "cost_usd",
      total_cost_microdollars: 500_000,
    });
    expect(estimateCostForRow(row, {})).toBe(500_000);
  });

  it("returns reported cost for token tracking when cost > 0", () => {
    const row = makeRow({
      tracking_type: "tokens",
      total_cost_microdollars: 100_000,
      total_input_tokens: 1000,
      total_output_tokens: 500,
    });
    expect(estimateCostForRow(row, {})).toBe(100_000);
  });

  it("estimates cost from default rate for token tracking with zero cost", () => {
    const row = makeRow({
      provider: "openai",
      tracking_type: "tokens",
      total_cost_microdollars: 0,
      total_input_tokens: 500,
      total_output_tokens: 500,
    });
    // 1000 tokens / 1000 * 0.005 USD * 1_000_000 = 5000
    expect(estimateCostForRow(row, {})).toBe(5000);
  });

  it("returns null for unknown token provider with zero cost", () => {
    const row = makeRow({
      provider: "unknown_provider",
      tracking_type: "tokens",
      total_cost_microdollars: 0,
    });
    expect(estimateCostForRow(row, {})).toBeNull();
  });

  it("uses per-run override when provided", () => {
    const row = makeRow({
      provider: "google_maps",
      tracking_type: "per_run",
      request_count: 10,
    });
    // override = 0.05 * 10 * 1_000_000 = 500_000
    expect(estimateCostForRow(row, { "google_maps:per_run": 0.05 })).toBe(
      500_000,
    );
  });

  it("uses default per-run cost when no override", () => {
    const row = makeRow({
      provider: "google_maps",
      tracking_type: null,
      request_count: 5,
    });
    // 0.032 * 5 * 1_000_000 = 160_000
    expect(estimateCostForRow(row, {})).toBe(160_000);
  });

  it("returns null for unknown per_run provider", () => {
    const row = makeRow({
      provider: "totally_unknown",
      tracking_type: "per_run",
      request_count: 3,
    });
    expect(estimateCostForRow(row, {})).toBeNull();
  });

  it("returns null for duration tracking with no rate and no cost", () => {
    const row = makeRow({
      provider: "openai",
      tracking_type: "duration_seconds",
      total_cost_microdollars: 0,
      total_duration_seconds: 100,
    });
    expect(estimateCostForRow(row, {})).toBeNull();
  });

  it("estimates cost from default rate for characters tracking", () => {
    const row = makeRow({
      provider: "elevenlabs",
      tracking_type: "characters",
      total_cost_microdollars: 0,
      total_tracking_amount: 2000,
    });
    // 2000 chars / 1000 * 0.18 USD * 1_000_000 = 360_000
    expect(estimateCostForRow(row, {})).toBe(360_000);
  });

  it("estimates cost from default rate for items tracking", () => {
    const row = makeRow({
      provider: "apollo",
      tracking_type: "items",
      total_cost_microdollars: 0,
      total_tracking_amount: 50,
    });
    // 50 * 0.02 * 1_000_000 = 1_000_000
    expect(estimateCostForRow(row, {})).toBe(1_000_000);
  });

  it("estimates cost from default rate for duration tracking", () => {
    const row = makeRow({
      provider: "e2b",
      tracking_type: "sandbox_seconds",
      total_cost_microdollars: 0,
      total_duration_seconds: 1_000_000,
    });
    // 1_000_000 * 0.000014 * 1_000_000 = 14_000_000
    expect(estimateCostForRow(row, {})).toBe(14_000_000);
  });
});

describe("trackingValue", () => {
  it("returns formatted microdollars for cost_usd", () => {
    const row = makeRow({
      tracking_type: "cost_usd",
      total_cost_microdollars: 1_000_000,
    });
    expect(trackingValue(row)).toBe("$1.0000");
  });

  it("returns formatted token count for tokens", () => {
    const row = makeRow({
      tracking_type: "tokens",
      total_input_tokens: 500,
      total_output_tokens: 500,
    });
    expect(trackingValue(row)).toBe("1.0K tokens");
  });

  it("returns formatted duration in minutes for sandbox_seconds", () => {
    const row = makeRow({
      tracking_type: "sandbox_seconds",
      total_duration_seconds: 120,
    });
    expect(trackingValue(row)).toBe("2.0m");
  });

  it("returns run count for per_run (default tracking)", () => {
    const row = makeRow({
      tracking_type: null,
      request_count: 42,
    });
    expect(trackingValue(row)).toBe("42 runs");
  });

  it("returns formatted character count for characters tracking", () => {
    const row = makeRow({
      tracking_type: "characters",
      total_tracking_amount: 2500,
    });
    expect(trackingValue(row)).toBe("2.5K chars");
  });

  it("returns formatted item count for items tracking", () => {
    const row = makeRow({
      tracking_type: "items",
      total_tracking_amount: 1234,
    });
    expect(trackingValue(row)).toBe("1,234 items");
  });

  it("returns formatted duration in hours for sandbox_seconds", () => {
    const row = makeRow({
      tracking_type: "sandbox_seconds",
      total_duration_seconds: 7200,
    });
    expect(trackingValue(row)).toBe("2.0h");
  });

  it("returns formatted duration for walltime_seconds", () => {
    const row = makeRow({
      tracking_type: "walltime_seconds",
      total_duration_seconds: 45,
    });
    expect(trackingValue(row)).toBe("45.0s");
  });
});

describe("toLocalInput", () => {
  it("returns empty string for empty input", () => {
    expect(toLocalInput("")).toBe("");
  });

  it("returns empty string for invalid ISO", () => {
    expect(toLocalInput("not-a-date")).toBe("");
  });

  it("converts UTC ISO to local datetime-local format", () => {
    const result = toLocalInput("2026-01-15T12:30:00Z");
    // Format should be YYYY-MM-DDTHH:mm
    expect(result).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$/);
  });
});

describe("toUtcIso", () => {
  it("returns empty string for empty input", () => {
    expect(toUtcIso("")).toBe("");
  });

  it("returns empty string for invalid local time", () => {
    expect(toUtcIso("not-a-date")).toBe("");
  });

  it("converts local datetime-local to ISO string", () => {
    const result = toUtcIso("2026-01-15T12:30");
    expect(result).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/);
  });
});
@@ -0,0 +1,45 @@
import {
  getV2GetPlatformCostDashboard,
  getV2GetPlatformCostLogs,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { okData } from "@/app/api/helpers";

// Backend expects ISO datetime strings. The generated client's URL builder
// calls .toString() on values, which for Date objects produces the human
// "Tue Mar 31 2026 22:00:00 GMT+0000 (Coordinated Universal Time)" format
// that FastAPI rejects with 422. We already pass UTC ISO from the URL, so
// forward the raw strings through the `as unknown as Date` cast to match
// the generated typing without triggering Date.toString().
export async function getPlatformCostDashboard(params?: {
  start?: string;
  end?: string;
  provider?: string;
  user_id?: string;
}) {
  const response = await getV2GetPlatformCostDashboard({
    start: (params?.start || undefined) as unknown as Date | undefined,
    end: (params?.end || undefined) as unknown as Date | undefined,
    provider: params?.provider || undefined,
    user_id: params?.user_id || undefined,
  });
  return okData(response);
}

export async function getPlatformCostLogs(params?: {
  start?: string;
  end?: string;
  provider?: string;
  user_id?: string;
  page?: number;
  page_size?: number;
}) {
  const response = await getV2GetPlatformCostLogs({
    start: (params?.start || undefined) as unknown as Date | undefined,
    end: (params?.end || undefined) as unknown as Date | undefined,
    provider: params?.provider || undefined,
    user_id: params?.user_id || undefined,
    page: params?.page,
    page_size: params?.page_size,
  });
  return okData(response);
}
@@ -0,0 +1,140 @@
import type { CostLogRow } from "@/app/api/__generated__/models/costLogRow";
import type { Pagination } from "@/app/api/__generated__/models/pagination";
import { formatDuration, formatMicrodollars, formatTokens } from "../helpers";
import { TrackingBadge } from "./TrackingBadge";

function formatLogDate(value: unknown): string {
  if (value instanceof Date) return value.toLocaleString();
  if (typeof value === "string" || typeof value === "number")
    return new Date(value).toLocaleString();
  return "-";
}

interface Props {
  logs: CostLogRow[];
  pagination: Pagination | null;
  onPageChange: (page: number) => void;
}

function LogsTable({ logs, pagination, onPageChange }: Props) {
  return (
    <div className="flex flex-col gap-4">
      <div className="overflow-x-auto">
        <table className="w-full text-left text-sm">
          <thead className="border-b text-xs uppercase text-muted-foreground">
            <tr>
              <th scope="col" className="px-3 py-3">
                Time
              </th>
              <th scope="col" className="px-3 py-3">
                User
              </th>
              <th scope="col" className="px-3 py-3">
                Block
              </th>
              <th scope="col" className="px-3 py-3">
                Provider
              </th>
              <th scope="col" className="px-3 py-3">
                Type
              </th>
              <th scope="col" className="px-3 py-3">
                Model
              </th>
              <th scope="col" className="px-3 py-3 text-right">
                Cost
              </th>
              <th scope="col" className="px-3 py-3 text-right">
                Tokens
              </th>
              <th scope="col" className="px-3 py-3 text-right">
                Duration
              </th>
              <th scope="col" className="px-3 py-3">
                Execution
              </th>
            </tr>
          </thead>
          <tbody>
            {logs.map((log) => (
              <tr key={log.id} className="border-b hover:bg-muted">
                <td className="whitespace-nowrap px-3 py-2 text-xs">
                  {formatLogDate(log.created_at)}
                </td>
                <td className="px-3 py-2 text-xs">
                  {log.email ||
                    (log.user_id ? String(log.user_id).slice(0, 8) : "-")}
                </td>
                <td className="px-3 py-2 text-xs font-medium">
                  {log.block_name}
                </td>
                <td className="px-3 py-2 text-xs">{log.provider}</td>
                <td className="px-3 py-2 text-xs">
                  <TrackingBadge trackingType={log.tracking_type} />
                </td>
                <td className="px-3 py-2 text-xs">{log.model || "-"}</td>
                <td className="px-3 py-2 text-right text-xs">
                  {log.cost_microdollars != null
                    ? formatMicrodollars(Number(log.cost_microdollars))
                    : "-"}
                </td>
                <td className="px-3 py-2 text-right text-xs">
                  {log.input_tokens != null || log.output_tokens != null
                    ? `${formatTokens(Number(log.input_tokens ?? 0))} / ${formatTokens(Number(log.output_tokens ?? 0))}`
                    : "-"}
                </td>
                <td className="px-3 py-2 text-right text-xs">
                  {log.duration != null
                    ? formatDuration(Number(log.duration))
                    : "-"}
                </td>
                <td className="px-3 py-2 text-xs text-muted-foreground">
                  {log.graph_exec_id
                    ? String(log.graph_exec_id).slice(0, 8)
                    : "-"}
                </td>
              </tr>
            ))}
            {logs.length === 0 && (
              <tr>
                <td
                  colSpan={10}
                  className="px-4 py-8 text-center text-muted-foreground"
                >
                  No logs found
                </td>
              </tr>
            )}
          </tbody>
        </table>
      </div>

      {pagination && pagination.total_pages > 1 && (
        <div className="flex items-center justify-between px-4">
          <span className="text-sm text-muted-foreground">
            Page {pagination.current_page} of {pagination.total_pages} (
            {pagination.total_items} total)
          </span>
          <div className="flex gap-2">
            <button
              disabled={pagination.current_page <= 1}
              onClick={() => onPageChange(pagination.current_page - 1)}
              className="rounded border px-3 py-1 text-sm disabled:opacity-50"
            >
              Previous
            </button>
            <button
              disabled={pagination.current_page >= pagination.total_pages}
              onClick={() => onPageChange(pagination.current_page + 1)}
              className="rounded border px-3 py-1 text-sm disabled:opacity-50"
            >
              Next
            </button>
          </div>
        </div>
      )}
    </div>
  );
}

export { LogsTable };
@@ -0,0 +1,234 @@
"use client";

import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { formatMicrodollars } from "../helpers";
import { SummaryCard } from "./SummaryCard";
import { ProviderTable } from "./ProviderTable";
import { UserTable } from "./UserTable";
import { LogsTable } from "./LogsTable";
import { usePlatformCostContent } from "./usePlatformCostContent";

interface Props {
  searchParams: {
    start?: string;
    end?: string;
    provider?: string;
    user_id?: string;
    page?: string;
    tab?: string;
  };
}

export function PlatformCostContent({ searchParams }: Props) {
  const {
    dashboard,
    logs,
    pagination,
    loading,
    error,
    totalEstimatedCost,
    tab,
    startInput,
    setStartInput,
    endInput,
    setEndInput,
    providerInput,
    setProviderInput,
    userInput,
    setUserInput,
    rateOverrides,
    handleRateOverride,
    updateUrl,
    handleFilter,
  } = usePlatformCostContent(searchParams);

  return (
    <div className="flex flex-col gap-6">
      <div className="flex flex-wrap items-end gap-3 rounded-lg border p-4">
        <div className="flex flex-col gap-1">
          <label htmlFor="start-date" className="text-sm text-muted-foreground">
            Start Date{" "}
            <span className="text-xs">
              (local time — defaults to last 30 days)
            </span>
          </label>
          <input
            id="start-date"
            type="datetime-local"
            className="rounded border px-3 py-1.5 text-sm"
            value={startInput}
            onChange={(e) => setStartInput(e.target.value)}
          />
        </div>
        <div className="flex flex-col gap-1">
          <label htmlFor="end-date" className="text-sm text-muted-foreground">
            End Date <span className="text-xs">(local time)</span>
          </label>
          <input
            id="end-date"
            type="datetime-local"
            className="rounded border px-3 py-1.5 text-sm"
            value={endInput}
            onChange={(e) => setEndInput(e.target.value)}
          />
        </div>
        <div className="flex flex-col gap-1">
          <label
            htmlFor="provider-filter"
            className="text-sm text-muted-foreground"
          >
            Provider
          </label>
          <input
            id="provider-filter"
            type="text"
            placeholder="e.g. openai"
            className="rounded border px-3 py-1.5 text-sm"
            value={providerInput}
            onChange={(e) => setProviderInput(e.target.value)}
          />
        </div>
        <div className="flex flex-col gap-1">
          <label
            htmlFor="user-id-filter"
            className="text-sm text-muted-foreground"
          >
            User ID
          </label>
          <input
            id="user-id-filter"
            type="text"
            placeholder="Filter by user"
            className="rounded border px-3 py-1.5 text-sm"
            value={userInput}
            onChange={(e) => setUserInput(e.target.value)}
          />
        </div>
        <button
          onClick={handleFilter}
          className="rounded bg-primary px-4 py-1.5 text-sm text-primary-foreground hover:bg-primary/90"
        >
          Apply
        </button>
        <button
          onClick={() => {
            setStartInput("");
            setEndInput("");
            setProviderInput("");
            setUserInput("");
            updateUrl({
              start: "",
              end: "",
              provider: "",
              user_id: "",
              page: "1",
            });
          }}
          className="rounded border px-4 py-1.5 text-sm hover:bg-muted"
        >
          Clear
        </button>
      </div>

      {error && (
        <Alert variant="error">
          <AlertDescription>{error}</AlertDescription>
        </Alert>
      )}

      {loading ? (
        <div className="flex flex-col gap-4">
          <div className="grid grid-cols-2 gap-4 md:grid-cols-4">
            {[...Array(4)].map((_, i) => (
              <Skeleton key={i} className="h-20 rounded-lg" />
            ))}
          </div>
          <Skeleton className="h-8 w-48 rounded" />
          <Skeleton className="h-64 rounded-lg" />
        </div>
      ) : (
        <>
          {dashboard && (
            <div className="grid grid-cols-2 gap-4 md:grid-cols-4">
              <SummaryCard
                label="Known Cost"
                value={formatMicrodollars(dashboard.total_cost_microdollars)}
                subtitle="From providers that report USD cost"
              />
              <SummaryCard
                label="Estimated Total"
                value={formatMicrodollars(totalEstimatedCost)}
                subtitle="Including per-run cost estimates"
              />
              <SummaryCard
                label="Total Requests"
                value={dashboard.total_requests.toLocaleString()}
              />
              <SummaryCard
                label="Active Users"
                value={dashboard.total_users.toLocaleString()}
              />
            </div>
          )}

          <div
            role="tablist"
            aria-label="Cost view tabs"
            className="flex gap-2 border-b"
          >
            {["overview", "by-user", "logs"].map((t) => (
              <button
                key={t}
                id={`tab-${t}`}
                role="tab"
                aria-selected={tab === t}
                aria-controls={`tabpanel-${t}`}
                onClick={() => updateUrl({ tab: t, page: "1" })}
                className={`px-4 py-2 text-sm font-medium ${tab === t ? "border-b-2 border-primary text-primary" : "text-muted-foreground hover:text-foreground"}`}
              >
                {t === "overview"
                  ? "By Provider"
                  : t === "by-user"
                    ? "By User"
                    : "Raw Logs"}
              </button>
            ))}
          </div>

          {tab === "overview" && dashboard && (
            <div
              role="tabpanel"
              id="tabpanel-overview"
              aria-labelledby="tab-overview"
            >
              <ProviderTable
                data={dashboard.by_provider}
                rateOverrides={rateOverrides}
                onRateOverride={handleRateOverride}
              />
            </div>
          )}
          {tab === "by-user" && dashboard && (
            <div
              role="tabpanel"
              id="tabpanel-by-user"
              aria-labelledby="tab-by-user"
            >
              <UserTable data={dashboard.by_user} />
            </div>
          )}
          {tab === "logs" && (
            <div role="tabpanel" id="tabpanel-logs" aria-labelledby="tab-logs">
              <LogsTable
                logs={logs}
                pagination={pagination}
                onPageChange={(p) => updateUrl({ page: p.toString() })}
              />
            </div>
          )}
        </>
      )}
    </div>
  );
}
@@ -0,0 +1,131 @@
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
import {
  defaultRateFor,
  estimateCostForRow,
  formatMicrodollars,
  rateKey,
  rateUnitLabel,
  trackingValue,
} from "../helpers";
import { TrackingBadge } from "./TrackingBadge";

interface Props {
  data: ProviderCostSummary[];
  rateOverrides: Record<string, number>;
  onRateOverride: (key: string, val: number | null) => void;
}

function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
  return (
    <div className="overflow-x-auto">
      <table className="w-full text-left text-sm">
        <thead className="border-b text-xs uppercase text-muted-foreground">
          <tr>
            <th scope="col" className="px-4 py-3">
              Provider
            </th>
            <th scope="col" className="px-4 py-3">
              Type
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Usage
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Requests
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Known Cost
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Est. Cost
            </th>
            <th
              scope="col"
              className="px-4 py-3 text-right"
              title="Per-session only"
            >
              Rate <span className="text-[10px] font-normal">(unsaved)</span>
            </th>
          </tr>
        </thead>
        <tbody>
          {data.map((row) => {
            const est = estimateCostForRow(row, rateOverrides);
            const tt = row.tracking_type || "per_run";
            // For cost_usd rows the provider reports USD directly so rate
            // input doesn't apply; otherwise show an editable input.
            const showRateInput = tt !== "cost_usd";
            const key = rateKey(row.provider, tt);
            const fallback = defaultRateFor(row.provider, tt);
            const currentRate = rateOverrides[key] ?? fallback;
            return (
              <tr key={key} className="border-b hover:bg-muted">
                <td className="px-4 py-3 font-medium">{row.provider}</td>
                <td className="px-4 py-3">
                  <TrackingBadge trackingType={row.tracking_type} />
                </td>
                <td className="px-4 py-3 text-right">{trackingValue(row)}</td>
                <td className="px-4 py-3 text-right">
                  {row.request_count.toLocaleString()}
                </td>
                <td className="px-4 py-3 text-right">
                  {row.total_cost_microdollars > 0
                    ? formatMicrodollars(row.total_cost_microdollars)
                    : "-"}
                </td>
                <td className="px-4 py-3 text-right">
                  {est !== null ? (
                    formatMicrodollars(est)
                  ) : (
                    <span className="text-muted-foreground">-</span>
                  )}
                </td>
                <td className="px-4 py-2 text-right">
                  {showRateInput ? (
                    <div className="flex items-center justify-end gap-1">
                      <input
                        type="number"
                        step="0.0001"
                        min="0"
                        aria-label={`Rate for ${row.provider} (${tt})`}
                        className="w-24 rounded border px-2 py-1 text-right text-xs"
                        placeholder={fallback !== null ? String(fallback) : "0"}
                        value={currentRate ?? ""}
                        onChange={(e) => {
                          const val = parseFloat(e.target.value);
                          if (!isNaN(val)) onRateOverride(key, val);
                          else if (e.target.value === "")
                            onRateOverride(key, null);
                        }}
                      />
                      <span
                        className="text-[10px] text-muted-foreground"
                        title={rateUnitLabel(tt)}
                      >
                        {rateUnitLabel(tt)}
                      </span>
                    </div>
                  ) : (
                    <span className="text-xs text-muted-foreground">auto</span>
                  )}
                </td>
              </tr>
            );
          })}
          {data.length === 0 && (
            <tr>
              <td
                colSpan={7}
                className="px-4 py-8 text-center text-muted-foreground"
              >
                No cost data yet
              </td>
            </tr>
          )}
        </tbody>
      </table>
    </div>
  );
}

export { ProviderTable };
@@ -0,0 +1,19 @@
interface Props {
  label: string;
  value: string;
  subtitle?: string;
}

function SummaryCard({ label, value, subtitle }: Props) {
  return (
    <div className="rounded-lg border p-4">
      <div className="text-sm text-muted-foreground">{label}</div>
      <div className="text-2xl font-bold">{value}</div>
      {subtitle && (
        <div className="mt-1 text-xs text-muted-foreground">{subtitle}</div>
      )}
    </div>
  );
}

export { SummaryCard };
@@ -0,0 +1,25 @@
function TrackingBadge({
  trackingType,
}: {
  trackingType: string | null | undefined;
}) {
  const colors: Record<string, string> = {
    cost_usd: "bg-green-500/10 text-green-700",
    tokens: "bg-blue-500/10 text-blue-700",
    characters: "bg-purple-500/10 text-purple-700",
    sandbox_seconds: "bg-orange-500/10 text-orange-700",
    walltime_seconds: "bg-orange-500/10 text-orange-700",
    items: "bg-pink-500/10 text-pink-700",
    per_run: "bg-muted text-muted-foreground",
  };
  const label = trackingType || "per_run";
  return (
    <span
      className={`inline-block rounded px-1.5 py-0.5 text-[10px] font-medium ${colors[label] || colors.per_run}`}
    >
      {label}
    </span>
  );
}

export { TrackingBadge };
@@ -0,0 +1,75 @@
import type { PlatformCostDashboard } from "@/app/api/__generated__/models/platformCostDashboard";
import { formatMicrodollars, formatTokens } from "../helpers";

interface Props {
  data: PlatformCostDashboard["by_user"];
}

function UserTable({ data }: Props) {
  return (
    <div className="overflow-x-auto">
      <table className="w-full text-left text-sm">
        <thead className="border-b text-xs uppercase text-muted-foreground">
          <tr>
            <th scope="col" className="px-4 py-3">
              User
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Known Cost
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Requests
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Input Tokens
            </th>
            <th scope="col" className="px-4 py-3 text-right">
              Output Tokens
            </th>
          </tr>
        </thead>
        <tbody>
          {data.map((row, idx) => (
            <tr
              key={row.user_id ?? `unknown-${idx}`}
              className="border-b hover:bg-muted"
            >
              <td className="px-4 py-3">
                <div className="font-medium">{row.email || "Unknown"}</div>
                <div className="text-xs text-muted-foreground">
                  {row.user_id}
                </div>
              </td>
              <td className="px-4 py-3 text-right">
                {row.total_cost_microdollars > 0
                  ? formatMicrodollars(row.total_cost_microdollars)
                  : "-"}
              </td>
              <td className="px-4 py-3 text-right">
                {row.request_count.toLocaleString()}
              </td>
              <td className="px-4 py-3 text-right">
                {formatTokens(row.total_input_tokens)}
              </td>
              <td className="px-4 py-3 text-right">
                {formatTokens(row.total_output_tokens)}
              </td>
            </tr>
          ))}
          {data.length === 0 && (
            <tr>
              <td
                colSpan={5}
                className="px-4 py-8 text-center text-muted-foreground"
              >
                No cost data yet
              </td>
            </tr>
          )}
        </tbody>
      </table>
    </div>
  );
}

export { UserTable };
@@ -0,0 +1,136 @@
"use client";

import { useRouter, useSearchParams } from "next/navigation";
import { useState } from "react";
import {
  useGetV2GetPlatformCostDashboard,
  useGetV2GetPlatformCostLogs,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { okData } from "@/app/api/helpers";
import { estimateCostForRow, toLocalInput, toUtcIso } from "../helpers";

interface InitialSearchParams {
  start?: string;
  end?: string;
  provider?: string;
  user_id?: string;
  page?: string;
  tab?: string;
}

export function usePlatformCostContent(searchParams: InitialSearchParams) {
  const router = useRouter();
  const urlParams = useSearchParams();

  const tab = urlParams.get("tab") || searchParams.tab || "overview";
  const page = parseInt(urlParams.get("page") || searchParams.page || "1", 10);
  const startDate = urlParams.get("start") || searchParams.start || "";
  const endDate = urlParams.get("end") || searchParams.end || "";
  const providerFilter =
    urlParams.get("provider") || searchParams.provider || "";
  const userFilter = urlParams.get("user_id") || searchParams.user_id || "";

  const [startInput, setStartInput] = useState(toLocalInput(startDate));
  const [endInput, setEndInput] = useState(toLocalInput(endDate));
  const [providerInput, setProviderInput] = useState(providerFilter);
  const [userInput, setUserInput] = useState(userFilter);
  const [rateOverrides, setRateOverrides] = useState<Record<string, number>>(
    {},
  );

  // Pass ISO date strings through `as unknown as Date` so Orval's URL builder
  // forwards them as-is. Date.toString() produces a format FastAPI rejects;
  // strings pass through .toString() unchanged.
  const filterParams = {
    start: (startDate || undefined) as unknown as Date | undefined,
    end: (endDate || undefined) as unknown as Date | undefined,
    provider: providerFilter || undefined,
    user_id: userFilter || undefined,
  };

  const {
    data: dashboard,
    isLoading: dashLoading,
    error: dashError,
  } = useGetV2GetPlatformCostDashboard(filterParams, {
    query: { select: okData },
  });

  const {
    data: logsResponse,
    isLoading: logsLoading,
    error: logsError,
  } = useGetV2GetPlatformCostLogs(
    { ...filterParams, page, page_size: 50 },
    { query: { select: okData } },
  );

  const loading = dashLoading || logsLoading;
  const error = dashError
    ? dashError instanceof Error
      ? dashError.message
      : "Failed to load dashboard"
    : logsError
      ? logsError instanceof Error
        ? logsError.message
        : "Failed to load logs"
      : null;

  function updateUrl(overrides: Record<string, string>) {
    const params = new URLSearchParams(urlParams.toString());
    for (const [k, v] of Object.entries(overrides)) {
      if (v) params.set(k, v);
      else params.delete(k);
    }
    router.push(`/admin/platform-costs?${params.toString()}`);
  }

  function handleFilter() {
    updateUrl({
      start: toUtcIso(startInput),
      end: toUtcIso(endInput),
      provider: providerInput,
      user_id: userInput,
      page: "1",
    });
  }

  function handleRateOverride(key: string, val: number | null) {
    setRateOverrides((prev) => {
      if (val === null) {
        const { [key]: _, ...rest } = prev;
        return rest;
      }
      return { ...prev, [key]: val };
    });
  }

  const totalEstimatedCost =
    dashboard?.by_provider.reduce((sum, row) => {
      const est = estimateCostForRow(row, rateOverrides);
      return sum + (est ?? 0);
    }, 0) ?? 0;

  return {
    dashboard: dashboard ?? null,
    logs: logsResponse?.logs ?? [],
    pagination: logsResponse?.pagination ?? null,
    loading,
    error,
    totalEstimatedCost,
    tab,
    page,
    startInput,
    setStartInput,
    endInput,
    setEndInput,
    providerInput,
    setProviderInput,
    userInput,
    setUserInput,
    rateOverrides,
    handleRateOverride,
    updateUrl,
    handleFilter,
  };
}
@@ -0,0 +1,204 @@
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";

const MICRODOLLARS_PER_USD = 1_000_000;

// Per-request cost estimates (USD) for providers billed per API call.
export const DEFAULT_COST_PER_RUN: Record<string, number> = {
  google_maps: 0.032, // $0.032/request - Google Maps Places API
  ideogram: 0.08, // $0.08/image - Ideogram standard generation
  nvidia: 0.0, // Free tier - NVIDIA NIM deepfake detection
  screenshotone: 0.01, // ~$0.01/screenshot - ScreenshotOne starter
  zerobounce: 0.008, // $0.008/validation - ZeroBounce
  mem0: 0.01, // ~$0.01/request - Mem0
  openweathermap: 0.0, // Free tier
  webshare_proxy: 0.0, // Flat subscription
  enrichlayer: 0.1, // ~$0.10/profile lookup
  jina: 0.0, // Free tier
};

export const DEFAULT_COST_PER_1K_TOKENS: Record<string, number> = {
  openai: 0.005,
  anthropic: 0.008,
  groq: 0.0003,
  ollama: 0.0,
  aiml_api: 0.005,
  llama_api: 0.003,
  v0: 0.005,
};

// Per-character rates (USD / 1K characters) for TTS providers.
export const DEFAULT_COST_PER_1K_CHARS: Record<string, number> = {
  unreal_speech: 0.008, // ~$8/1M chars on Starter
  elevenlabs: 0.18, // ~$0.18/1K chars on Starter
  d_id: 0.04, // ~$0.04/1K chars estimated
};

// Per-item rates (USD / item) for item-count billed APIs.
export const DEFAULT_COST_PER_ITEM: Record<string, number> = {
  google_maps: 0.017, // avg of $0.032 nearby + ~$0.015 detail enrich
  apollo: 0.02, // ~$0.02/contact on low-volume tiers
  smartlead: 0.001, // ~$0.001/lead added
};

// Per-second rates (USD / second) for duration-billed providers.
export const DEFAULT_COST_PER_SECOND: Record<string, number> = {
  e2b: 0.000014, // $0.000014/sec (2-core sandbox)
  fal: 0.0005, // varies by model, conservative
  replicate: 0.001, // varies by hardware
  revid: 0.01, // per-second of video
};

export function toDateOrUndefined(val?: string): Date | undefined {
  if (!val) return undefined;
  const d = new Date(val);
  return isNaN(d.getTime()) ? undefined : d;
}

export function formatMicrodollars(microdollars: number) {
  return `$${(microdollars / MICRODOLLARS_PER_USD).toFixed(4)}`;
}

export function formatTokens(tokens: number) {
  if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`;
  if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(1)}K`;
  return tokens.toString();
}

export function formatDuration(seconds: number) {
  if (seconds >= 3600) return `${(seconds / 3600).toFixed(1)}h`;
  if (seconds >= 60) return `${(seconds / 60).toFixed(1)}m`;
  return `${seconds.toFixed(1)}s`;
}

// Unit label for each tracking type — what the rate input represents.
export function rateUnitLabel(trackingType: string | null | undefined): string {
  switch (trackingType) {
    case "tokens":
      return "$/1K tokens";
    case "characters":
      return "$/1K chars";
    case "items":
      return "$/item";
    case "sandbox_seconds":
    case "walltime_seconds":
      return "$/second";
    case "per_run":
      return "$/run";
    default:
      return "";
  }
}

// Default rate for a (provider, tracking_type) pair.
export function defaultRateFor(
  provider: string,
  trackingType: string | null | undefined,
): number | null {
  switch (trackingType) {
    case "tokens":
      return DEFAULT_COST_PER_1K_TOKENS[provider] ?? null;
    case "characters":
      return DEFAULT_COST_PER_1K_CHARS[provider] ?? null;
    case "items":
      return DEFAULT_COST_PER_ITEM[provider] ?? null;
    case "sandbox_seconds":
    case "walltime_seconds":
      return DEFAULT_COST_PER_SECOND[provider] ?? null;
    case "per_run":
      return DEFAULT_COST_PER_RUN[provider] ?? null;
    default:
      return null;
  }
}

// Overrides are keyed on `${provider}:${tracking_type}` since the same
// provider can have multiple rows with different billing models.
export function rateKey(
  provider: string,
  trackingType: string | null | undefined,
): string {
  return `${provider}:${trackingType ?? "per_run"}`;
}

export function estimateCostForRow(
  row: ProviderCostSummary,
  rateOverrides: Record<string, number>,
) {
  const tt = row.tracking_type || "per_run";

  // Providers that report USD directly: use known cost.
  if (tt === "cost_usd") return row.total_cost_microdollars;

  // Prefer the real USD the provider reported if any, but only for token paths
  // where OpenRouter piggybacks on the tokens row via x-total-cost.
  if (tt === "tokens" && row.total_cost_microdollars > 0) {
    return row.total_cost_microdollars;
  }

  const rate =
    rateOverrides[rateKey(row.provider, tt)] ??
    defaultRateFor(row.provider, tt);
  if (rate === null || rate === undefined) return null;

  // Compute the amount for this tracking type, then multiply by rate.
  let amount: number;
  switch (tt) {
    case "tokens":
      // Rate is per-1K tokens.
      amount = (row.total_input_tokens + row.total_output_tokens) / 1000;
      break;
    case "characters":
      // Rate is per-1K chars. trackingAmount aggregates char counts.
      amount = (row.total_tracking_amount || 0) / 1000;
      break;
    case "items":
      amount = row.total_tracking_amount || 0;
      break;
    case "sandbox_seconds":
    case "walltime_seconds":
      amount = row.total_duration_seconds || 0;
      break;
    case "per_run":
      amount = row.request_count;
      break;
    default:
      return row.total_cost_microdollars > 0
        ? row.total_cost_microdollars
        : null;
  }

  return Math.round(rate * amount * MICRODOLLARS_PER_USD);
}

export function trackingValue(row: ProviderCostSummary) {
  const tt = row.tracking_type || "per_run";
  if (tt === "cost_usd") return formatMicrodollars(row.total_cost_microdollars);
  if (tt === "tokens") {
    const tokens = row.total_input_tokens + row.total_output_tokens;
    return `${formatTokens(tokens)} tokens`;
  }
  if (tt === "sandbox_seconds" || tt === "walltime_seconds")
    return formatDuration(row.total_duration_seconds || 0);
  if (tt === "characters")
    return `${formatTokens(Math.round(row.total_tracking_amount || 0))} chars`;
  if (tt === "items")
    return `${Math.round(row.total_tracking_amount || 0).toLocaleString()} items`;
  return `${row.request_count.toLocaleString()} runs`;
}

// URL holds UTC ISO; datetime-local inputs need local "YYYY-MM-DDTHH:mm".
export function toLocalInput(iso: string) {
  if (!iso) return "";
  const d = new Date(iso);
  if (isNaN(d.getTime())) return "";
  const pad = (n: number) => String(n).padStart(2, "0");
  return `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())}T${pad(d.getHours())}:${pad(d.getMinutes())}`;
}

// datetime-local emits naive local time; convert to UTC ISO so the
// backend filter window matches what the admin sees in their browser.
export function toUtcIso(local: string) {
  if (!local) return "";
  const d = new Date(local);
  return isNaN(d.getTime()) ? "" : d.toISOString();
}
@@ -0,0 +1,50 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { Suspense } from "react";
import { PlatformCostContent } from "./components/PlatformCostContent";

type SearchParams = {
  start?: string;
  end?: string;
  provider?: string;
  user_id?: string;
  page?: string;
  tab?: string;
};

function PlatformCostDashboard({
  searchParams,
}: {
  searchParams: SearchParams;
}) {
  return (
    <div className="mx-auto p-6">
      <div className="flex flex-col gap-4">
        <div>
          <h1 className="text-3xl font-bold">Platform Costs</h1>
          <p className="text-muted-foreground">
            Track real API costs incurred by system credentials across providers
          </p>
        </div>

        <Suspense
          key={JSON.stringify(searchParams)}
          fallback={
            <div className="py-10 text-center">Loading cost data...</div>
          }
        >
          <PlatformCostContent searchParams={searchParams} />
        </Suspense>
      </div>
    </div>
  );
}

export default async function PlatformCostDashboardPage({
  searchParams,
}: {
  searchParams: Promise<SearchParams>;
}) {
  const withAdminAccess = await withRoleAccess(["admin"]);
  const ProtectedDashboard = await withAdminAccess(PlatformCostDashboard);
  return <ProtectedDashboard searchParams={await searchParams} />;
}
@@ -40,14 +40,14 @@ export const ContentRenderer: React.FC<{
    !shortContent
  ) {
    return (
      <div className="overflow-hidden [&>*]:rounded-xlarge [&>*]:!text-xs [&_pre]:whitespace-pre-wrap [&_pre]:break-words">
      <div className="overflow-x-auto [&>*]:rounded-xlarge [&>*]:!text-xs [&_pre]:whitespace-pre-wrap [&_pre]:break-words">
        {renderer?.render(value, metadata)}
      </div>
    );
  }

  return (
    <div className="overflow-hidden [&>*]:rounded-xlarge [&>*]:!text-xs">
    <div className="overflow-x-auto [&>*]:rounded-xlarge [&>*]:!text-xs">
      <TextRenderer value={value} truncateLengthLimit={200} />
    </div>
  );
@@ -8,6 +8,7 @@ import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { SidebarProvider } from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { UploadSimple } from "@phosphor-icons/react";
import dynamic from "next/dynamic";
import { useCallback, useEffect, useRef, useState } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
@@ -20,6 +21,14 @@ import { RateLimitResetDialog } from "./components/RateLimitResetDialog/RateLimi
import { ScaleLoader } from "./components/ScaleLoader/ScaleLoader";
import { useCopilotPage } from "./useCopilotPage";

const ArtifactPanel = dynamic(
  () =>
    import("./components/ArtifactPanel/ArtifactPanel").then(
      (m) => m.ArtifactPanel,
    ),
  { ssr: false },
);

export function CopilotPage() {
  const [isDragging, setIsDragging] = useState(false);
  const [droppedFiles, setDroppedFiles] = useState<File[]>([]);
@@ -80,6 +89,10 @@ export function CopilotPage() {
    isUploadingFiles,
    isUserLoading,
    isLoggedIn,
    // Pagination
    hasMoreMessages,
    isLoadingMore,
    loadMore,
    // Mobile drawer
    isMobile,
    isDrawerOpen,
@@ -116,6 +129,7 @@ export function CopilotPage() {
  const resetCost = usage?.reset_cost;

  const isBillingEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
  const isArtifactsEnabled = useGetFlag(Flag.ARTIFACTS);
  const { credits, fetchCredits } = useCredits({ fetchInitialCredits: true });
  const hasInsufficientCredits =
    credits !== null && resetCost != null && credits < resetCost;
@@ -150,48 +164,55 @@ export function CopilotPage() {
      className="h-[calc(100vh-72px)] min-h-0"
    >
      {!isMobile && <ChatSidebar />}
      <div
        className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0"
        onDragEnter={handleDragEnter}
        onDragOver={handleDragOver}
        onDragLeave={handleDragLeave}
        onDrop={handleDrop}
      >
        {isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
        <NotificationBanner />
        {/* Drop overlay */}
      <div className="flex h-full w-full flex-row overflow-hidden">
        <div
          className={cn(
            "pointer-events-none absolute inset-0 z-50 flex flex-col items-center justify-center gap-3 rounded-lg border-2 border-dashed border-violet-400 bg-violet-500/10 transition-opacity duration-150",
            isDragging ? "opacity-100" : "opacity-0",
          )}
          className="relative flex min-w-0 flex-1 flex-col overflow-hidden bg-[#f8f8f9] px-0"
          onDragEnter={handleDragEnter}
          onDragOver={handleDragOver}
          onDragLeave={handleDragLeave}
          onDrop={handleDrop}
        >
          <UploadSimple className="h-10 w-10 text-violet-500" weight="bold" />
          <span className="text-lg font-medium text-violet-600">
            Drop files here
          </span>
        </div>
        <div className="flex-1 overflow-hidden">
          <ChatContainer
            messages={messages}
            status={status}
            error={error}
            sessionId={sessionId}
            isLoadingSession={isLoadingSession}
            isSessionError={isSessionError}
            isCreatingSession={isCreatingSession}
            isReconnecting={isReconnecting}
            isSyncing={isSyncing}
            onCreateSession={createSession}
            onSend={onSend}
            onStop={stop}
            isUploadingFiles={isUploadingFiles}
            droppedFiles={droppedFiles}
            onDroppedFilesConsumed={handleDroppedFilesConsumed}
            historicalDurations={historicalDurations}
          />
          {isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
          <NotificationBanner />
          {/* Drop overlay */}
          <div
            className={cn(
              "pointer-events-none absolute inset-0 z-50 flex flex-col items-center justify-center gap-3 rounded-lg border-2 border-dashed border-violet-400 bg-violet-500/10 transition-opacity duration-150",
              isDragging ? "opacity-100" : "opacity-0",
            )}
          >
            <UploadSimple className="h-10 w-10 text-violet-500" weight="bold" />
            <span className="text-lg font-medium text-violet-600">
              Drop files here
            </span>
          </div>
          <div className="flex-1 overflow-hidden">
            <ChatContainer
              messages={messages}
              status={status}
              error={error}
              sessionId={sessionId}
              isLoadingSession={isLoadingSession}
              isSessionError={isSessionError}
              isCreatingSession={isCreatingSession}
              isReconnecting={isReconnecting}
              isSyncing={isSyncing}
              onCreateSession={createSession}
              onSend={onSend}
              onStop={stop}
              isUploadingFiles={isUploadingFiles}
              hasMoreMessages={hasMoreMessages}
              isLoadingMore={isLoadingMore}
              onLoadMore={loadMore}
              droppedFiles={droppedFiles}
              onDroppedFilesConsumed={handleDroppedFilesConsumed}
              historicalDurations={historicalDurations}
            />
          </div>
        </div>
        {!isMobile && isArtifactsEnabled && <ArtifactPanel />}
      </div>
      {isMobile && isArtifactsEnabled && <ArtifactPanel mobile />}
      {isMobile && (
        <MobileDrawer
          isOpen={isDrawerOpen}
@@ -0,0 +1,114 @@
"use client";

import { toast } from "@/components/molecules/Toast/use-toast";
import { cn } from "@/lib/utils";
import { CaretRight, DownloadSimple } from "@phosphor-icons/react";
import type { ArtifactRef } from "../../store";
import { useCopilotUIStore } from "../../store";
import { downloadArtifact } from "../ArtifactPanel/downloadArtifact";
import { classifyArtifact } from "../ArtifactPanel/helpers";

interface Props {
  artifact: ArtifactRef;
}

function formatSize(bytes?: number): string {
  if (!bytes) return "";
  if (bytes < 1024) return `${bytes} B`;
  if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)} KB`;
  return `${(bytes / (1024 * 1024)).toFixed(1)} MB`;
}

export function ArtifactCard({ artifact }: Props) {
  const activeID = useCopilotUIStore((s) => s.artifactPanel.activeArtifact?.id);
  const isOpen = useCopilotUIStore((s) => s.artifactPanel.isOpen);
  const openArtifact = useCopilotUIStore((s) => s.openArtifact);

  const isActive = isOpen && activeID === artifact.id;
  const classification = classifyArtifact(
    artifact.mimeType,
    artifact.title,
    artifact.sizeBytes,
  );
  const Icon = classification.icon;

  function handleDownloadOnly() {
    downloadArtifact(artifact).catch(() => {
      toast({
        title: "Download failed",
        description: "Couldn't fetch the file.",
        variant: "destructive",
      });
    });
  }

  if (!classification.openable) {
    return (
      <button
        type="button"
        onClick={handleDownloadOnly}
        className="my-1 flex w-full items-center gap-3 rounded-lg border border-zinc-200 bg-white px-3 py-2.5 text-left transition-colors hover:bg-zinc-50"
      >
        <Icon size={20} className="shrink-0 text-zinc-400" />
        <div className="min-w-0 flex-1">
          <p className="truncate text-sm font-medium text-zinc-900">
            {artifact.title}
          </p>
          <p className="text-xs text-zinc-400">
            {classification.label}
            {artifact.sizeBytes
              ? ` \u2022 ${formatSize(artifact.sizeBytes)}`
              : ""}
          </p>
        </div>
        <DownloadSimple size={16} className="shrink-0 text-zinc-400" />
      </button>
    );
  }

  return (
    <button
      type="button"
      onClick={() => openArtifact(artifact)}
      className={cn(
        "my-1 flex w-full items-center gap-3 rounded-lg border bg-white px-3 py-2.5 text-left transition-colors hover:bg-zinc-50",
        isActive ? "border-violet-300 bg-violet-50/50" : "border-zinc-200",
      )}
    >
      <Icon
        size={20}
        className={cn(
          "shrink-0",
          isActive ? "text-violet-500" : "text-zinc-400",
        )}
      />
      <div className="min-w-0 flex-1">
        <p className="truncate text-sm font-medium text-zinc-900">
          {artifact.title}
        </p>
        <p className="text-xs text-zinc-400">
          <span
            className={cn(
              "inline-block rounded-full px-1.5 py-0.5 text-xs font-medium",
              artifact.origin === "user-upload"
                ? "bg-blue-50 text-blue-500"
                : "bg-violet-50 text-violet-500",
            )}
          >
            {classification.label}
          </span>
          {artifact.sizeBytes
            ? ` \u2022 ${formatSize(artifact.sizeBytes)}`
            : ""}
        </p>
      </div>
      <CaretRight
        size={16}
        className={cn(
          "shrink-0",
          isActive ? "text-violet-400" : "text-zinc-300",
        )}
      />
    </button>
  );
}
@@ -0,0 +1,125 @@
"use client";

import {
  Sheet,
  SheetContent,
  SheetHeader,
  SheetTitle,
} from "@/components/ui/sheet";
import { AnimatePresence, motion } from "framer-motion";
import { ArtifactContent } from "./components/ArtifactContent";
import { ArtifactDragHandle } from "./components/ArtifactDragHandle";
import { ArtifactMinimizedStrip } from "./components/ArtifactMinimizedStrip";
import { ArtifactPanelHeader } from "./components/ArtifactPanelHeader";
import { useArtifactPanel } from "./useArtifactPanel";

interface Props {
  mobile?: boolean;
}

export function ArtifactPanel({ mobile }: Props) {
  const {
    isOpen,
    isMinimized,
    isMaximized,
    activeArtifact,
    history,
    effectiveWidth,
    isSourceView,
    classification,
    setIsSourceView,
    closeArtifactPanel,
    minimizeArtifactPanel,
    maximizeArtifactPanel,
    restoreArtifactPanel,
    setArtifactPanelWidth,
    goBackArtifact,
    canCopy,
    handleCopy,
    handleDownload,
  } = useArtifactPanel();

  if (!activeArtifact || !classification) return null;

  const headerProps = {
    artifact: activeArtifact,
    classification,
    canGoBack: history.length > 0,
    isMaximized,
    isSourceView,
    hasSourceToggle: classification.hasSourceToggle,
    mobile: !!mobile,
    canCopy,
    onBack: goBackArtifact,
    onClose: closeArtifactPanel,
    onMinimize: minimizeArtifactPanel,
    onMaximize: maximizeArtifactPanel,
    onRestore: restoreArtifactPanel,
    onCopy: handleCopy,
    onDownload: handleDownload,
    onSourceToggle: setIsSourceView,
  };

  // Mobile: fullscreen Sheet overlay
  if (mobile) {
    return (
      <Sheet
        open={isOpen}
        onOpenChange={(open) => !open && closeArtifactPanel()}
      >
        <SheetContent
          side="right"
          className="flex w-full flex-col p-0 sm:max-w-full"
        >
          <SheetHeader className="sr-only">
            <SheetTitle>{activeArtifact.title}</SheetTitle>
          </SheetHeader>
          <ArtifactPanelHeader {...headerProps} />
          <ArtifactContent
            artifact={activeArtifact}
            isSourceView={isSourceView}
            classification={classification}
          />
        </SheetContent>
      </Sheet>
    );
  }

  // Minimized strip
  if (isOpen && isMinimized) {
    return (
      <ArtifactMinimizedStrip
        artifact={activeArtifact}
        classification={classification}
        onExpand={restoreArtifactPanel}
      />
    );
  }

  // Keep AnimatePresence mounted across the open→closed transition so the
  // exit animation on the motion.div has a chance to run.
  return (
    <AnimatePresence>
      {isOpen && (
        <motion.div
          key="artifact-panel"
          data-artifact-panel
          initial={{ opacity: 0 }}
          animate={{ opacity: 1 }}
          exit={{ opacity: 0 }}
          transition={{ duration: 0.25, ease: "easeInOut" }}
          className="relative flex h-full flex-col overflow-hidden border-l border-zinc-200 bg-white"
          style={{ width: effectiveWidth }}
        >
          <ArtifactDragHandle onWidthChange={setArtifactPanelWidth} />
          <ArtifactPanelHeader {...headerProps} />
          <ArtifactContent
            artifact={activeArtifact}
            isSourceView={isSourceView}
            classification={classification}
          />
        </motion.div>
      )}
    </AnimatePresence>
  );
}
@@ -0,0 +1,198 @@
"use client";

import { globalRegistry } from "@/components/contextual/OutputRenderers";
import { codeRenderer } from "@/components/contextual/OutputRenderers/renderers/CodeRenderer";
import { Suspense } from "react";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import { ArtifactReactPreview } from "./ArtifactReactPreview";
import { ArtifactSkeleton } from "./ArtifactSkeleton";
import {
  TAILWIND_CDN_URL,
  wrapWithHeadInjection,
} from "@/lib/iframe-sandbox-csp";
import { useArtifactContent } from "./useArtifactContent";

interface Props {
  artifact: ArtifactRef;
  isSourceView: boolean;
  classification: ArtifactClassification;
}

function ArtifactContentLoader({
  artifact,
  isSourceView,
  classification,
}: Props) {
  const { content, pdfUrl, isLoading, error, scrollRef, retry } =
    useArtifactContent(artifact, classification);

  if (isLoading) {
    return <ArtifactSkeleton extraLine />;
  }

  if (error) {
    return (
      <div
        role="alert"
        className="flex flex-col items-center justify-center gap-3 p-8 text-center"
      >
        <p className="text-sm text-zinc-500">Failed to load content</p>
        <p className="text-xs text-zinc-400">{error}</p>
        <button
          type="button"
          onClick={retry}
          className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
        >
          Try again
        </button>
      </div>
    );
  }

  return (
    <div ref={scrollRef} className="flex-1 overflow-y-auto">
      <ArtifactRenderer
        artifact={artifact}
        content={content}
        pdfUrl={pdfUrl}
        isSourceView={isSourceView}
        classification={classification}
      />
    </div>
  );
}

function ArtifactRenderer({
  artifact,
  content,
  pdfUrl,
  isSourceView,
  classification,
}: {
  artifact: ArtifactRef;
  content: string | null;
  pdfUrl: string | null;
  isSourceView: boolean;
  classification: ArtifactClassification;
}) {
  // Image: render directly from URL (no content fetch)
  if (classification.type === "image") {
    return (
      <div className="flex items-center justify-center p-4">
        {/* eslint-disable-next-line @next/next/no-img-element */}
        <img
          src={artifact.sourceUrl}
          alt={artifact.title}
          className="max-h-full max-w-full object-contain"
        />
      </div>
    );
  }

  if (classification.type === "pdf" && pdfUrl) {
    // No sandbox — Chrome/Edge block PDF rendering in sandboxed iframes
    // (Chromium bug #413851). The blob URL has a null origin so it can't
    // access the parent page regardless.
    return (
      <iframe src={pdfUrl} className="h-full w-full" title={artifact.title} />
    );
  }

  if (content === null) return null;

  // Source view: always show raw text
  if (isSourceView) {
    return (
      <pre className="whitespace-pre-wrap break-words p-4 font-mono text-sm text-zinc-800">
        {content}
      </pre>
    );
  }

  if (classification.type === "html") {
    // Inject Tailwind CDN — no CSP (see iframe-sandbox-csp.ts for why)
    const tailwindScript = `<script src="${TAILWIND_CDN_URL}"></script>`;
    const wrapped = wrapWithHeadInjection(content, tailwindScript);
    return (
      <iframe
        sandbox="allow-scripts"
        srcDoc={wrapped}
        className="h-full w-full border-0"
        title={artifact.title}
      />
    );
  }

  if (classification.type === "react") {
    return <ArtifactReactPreview source={content} title={artifact.title} />;
  }

  // Code: pass with explicit type metadata so CodeRenderer matches
  // (prevents higher-priority MarkdownRenderer from claiming it)
  if (classification.type === "code") {
    const ext = artifact.title.split(".").pop() ?? "";
    const codeMeta = {
      mimeType: artifact.mimeType ?? undefined,
      filename: artifact.title,
      type: "code",
      language: ext,
    };
    return <div className="p-4">{codeRenderer.render(content, codeMeta)}</div>;
  }

  // JSON: parse first so the JSONRenderer gets an object, not a string
  // (prevents higher-priority MarkdownRenderer from claiming it)
  if (classification.type === "json") {
    try {
      const parsed = JSON.parse(content);
      const jsonMeta = {
        mimeType: "application/json",
        type: "json",
        filename: artifact.title,
      };
      const jsonRenderer = globalRegistry.getRenderer(parsed, jsonMeta);
      if (jsonRenderer) {
        return (
          <div className="p-4">{jsonRenderer.render(parsed, jsonMeta)}</div>
        );
      }
    } catch {
      // invalid JSON — fall through to plain text
    }
  }

  // CSV: pass with explicit metadata so CSVRenderer matches
  if (classification.type === "csv") {
    const csvMeta = { mimeType: "text/csv", filename: artifact.title };
    const csvRenderer = globalRegistry.getRenderer(content, csvMeta);
    if (csvRenderer) {
      return <div className="p-4">{csvRenderer.render(content, csvMeta)}</div>;
    }
  }

  // Try the global renderer registry
  const metadata = {
    mimeType: artifact.mimeType ?? undefined,
    filename: artifact.title,
  };
  const renderer = globalRegistry.getRenderer(content, metadata);
  if (renderer) {
    return <div className="p-4">{renderer.render(content, metadata)}</div>;
  }

  // Fallback: plain text
  return (
    <pre className="whitespace-pre-wrap break-words p-4 font-mono text-sm text-zinc-800">
      {content}
    </pre>
  );
}

export function ArtifactContent(props: Props) {
  return (
    <Suspense fallback={<ArtifactSkeleton />}>
      <ArtifactContentLoader {...props} />
    </Suspense>
  );
}
@@ -0,0 +1,93 @@
"use client";

import { cn } from "@/lib/utils";
import { useEffect, useRef, useState } from "react";
import { DEFAULT_PANEL_WIDTH } from "../../../store";

interface Props {
  onWidthChange: (width: number) => void;
  minWidth?: number;
  maxWidthPercent?: number;
}

export function ArtifactDragHandle({
  onWidthChange,
  minWidth = 320,
  maxWidthPercent = 85,
}: Props) {
  const [isDragging, setIsDragging] = useState(false);
  const startXRef = useRef(0);
  const startWidthRef = useRef(0);
  // Use refs for the callback + bounds so the drag listeners can read the
  // latest values without having to detach/reattach between re-renders.
  const onWidthChangeRef = useRef(onWidthChange);
  const minWidthRef = useRef(minWidth);
  const maxWidthPercentRef = useRef(maxWidthPercent);
  onWidthChangeRef.current = onWidthChange;
  minWidthRef.current = minWidth;
  maxWidthPercentRef.current = maxWidthPercent;

  // Attach document listeners only while dragging, and always tear them down
  // on unmount — otherwise closing the panel mid-drag leaves listeners bound
  // to a handler that calls setState on the unmounted component.
  useEffect(() => {
    if (!isDragging) return;

    function handlePointerMove(moveEvent: PointerEvent) {
      const delta = startXRef.current - moveEvent.clientX;
      const maxWidth = window.innerWidth * (maxWidthPercentRef.current / 100);
      const newWidth = Math.min(
        maxWidth,
        Math.max(minWidthRef.current, startWidthRef.current + delta),
      );
      onWidthChangeRef.current(newWidth);
    }

    function handlePointerUp() {
      setIsDragging(false);
    }

    document.addEventListener("pointermove", handlePointerMove);
    document.addEventListener("pointerup", handlePointerUp);
    document.addEventListener("pointercancel", handlePointerUp);
    return () => {
      document.removeEventListener("pointermove", handlePointerMove);
      document.removeEventListener("pointerup", handlePointerUp);
      document.removeEventListener("pointercancel", handlePointerUp);
    };
  }, [isDragging]);

  function handlePointerDown(e: React.PointerEvent) {
    e.preventDefault();
    startXRef.current = e.clientX;

    // Get the panel's current width from its parent
    const panel = (e.target as HTMLElement).closest(
      "[data-artifact-panel]",
    ) as HTMLElement | null;
    startWidthRef.current = panel?.offsetWidth ?? DEFAULT_PANEL_WIDTH;

    setIsDragging(true);
  }

  return (
    // 12px transparent hit target with the visible 1px line centered inside
    // (WCAG-compliant, matches ~8-12px conventions of other resizable panels).
    <div
      role="separator"
      aria-orientation="vertical"
      aria-label="Resize panel"
      className={cn(
        "group absolute -left-1.5 top-0 z-10 flex h-full w-3 cursor-col-resize items-stretch justify-center",
      )}
      onPointerDown={handlePointerDown}
    >
      <div
        className={cn(
          "h-full w-px bg-transparent transition-colors group-hover:w-0.5 group-hover:bg-violet-400",
          isDragging && "w-0.5 bg-violet-500",
        )}
      />
    </div>
  );
}
@@ -0,0 +1,47 @@
"use client";

import { ArrowsOutSimple } from "@phosphor-icons/react";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";

interface Props {
  artifact: ArtifactRef;
  classification: ArtifactClassification;
  onExpand: () => void;
}

export function ArtifactMinimizedStrip({
  artifact,
  classification,
  onExpand,
}: Props) {
  const Icon = classification.icon;

  return (
    <div className="flex h-full w-10 flex-col items-center border-l border-zinc-200 bg-white pt-3">
      <button
        type="button"
        onClick={onExpand}
        className="rounded p-1.5 text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
        title="Expand panel"
      >
        <ArrowsOutSimple size={16} />
      </button>
      <div className="mt-3 text-zinc-400">
        <Icon size={16} />
      </div>
      <span
        className="mt-2 text-xs text-zinc-400"
        style={{
          writingMode: "vertical-rl",
          textOrientation: "mixed",
          maxHeight: "120px",
          overflow: "hidden",
          textOverflow: "ellipsis",
        }}
      >
        {artifact.title}
      </span>
    </div>
  );
}
@@ -0,0 +1,138 @@
"use client";

import { cn } from "@/lib/utils";
import {
  ArrowLeft,
  ArrowsIn,
  ArrowsOut,
  Copy,
  DownloadSimple,
  Minus,
  X,
} from "@phosphor-icons/react";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import { SourceToggle } from "./SourceToggle";

interface Props {
  artifact: ArtifactRef;
  classification: ArtifactClassification;
  canGoBack: boolean;
  isMaximized: boolean;
  isSourceView: boolean;
  hasSourceToggle: boolean;
  mobile?: boolean;
  canCopy?: boolean;
  onBack: () => void;
  onClose: () => void;
  onMinimize: () => void;
  onMaximize: () => void;
  onRestore: () => void;
  onCopy: () => void;
  onDownload: () => void;
  onSourceToggle: (isSource: boolean) => void;
}

function HeaderButton({
  onClick,
  title,
  children,
}: {
  onClick: () => void;
  title: string;
  children: React.ReactNode;
}) {
  return (
    <button
      type="button"
      onClick={onClick}
      title={title}
      aria-label={title}
      className="rounded p-1.5 text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700"
    >
      {children}
    </button>
  );
}

export function ArtifactPanelHeader({
  artifact,
  classification,
  canGoBack,
  isMaximized,
  isSourceView,
  hasSourceToggle,
  mobile,
  canCopy = true,
  onBack,
  onClose,
  onMinimize,
  onMaximize,
  onRestore,
  onCopy,
  onDownload,
  onSourceToggle,
}: Props) {
  const Icon = classification.icon;

  return (
    <div className="sticky top-0 z-10 flex items-center gap-2 border-b border-zinc-200 bg-white px-3 py-2">
      {/* Left section */}
      <div className="flex min-w-0 flex-1 items-center gap-2">
        {canGoBack && (
          <HeaderButton onClick={onBack} title="Back">
            <ArrowLeft size={16} />
          </HeaderButton>
        )}
        <Icon size={16} className="shrink-0 text-zinc-400" />
        <span className="truncate text-sm font-medium text-zinc-900">
          {artifact.title}
        </span>
        <span
          className={cn(
            "shrink-0 rounded-full px-2 py-0.5 text-xs font-medium",
            artifact.origin === "user-upload"
              ? "bg-blue-50 text-blue-600"
              : "bg-violet-50 text-violet-600",
          )}
        >
          {classification.label}
        </span>
      </div>

      {/* Right section */}
      <div className="flex items-center gap-1">
        {hasSourceToggle && (
          <SourceToggle isSourceView={isSourceView} onToggle={onSourceToggle} />
        )}
        {canCopy && (
          <HeaderButton onClick={onCopy} title="Copy">
            <Copy size={16} />
          </HeaderButton>
        )}
        <HeaderButton onClick={onDownload} title="Download">
          <DownloadSimple size={16} />
        </HeaderButton>
        {!mobile && (
          <>
            <HeaderButton onClick={onMinimize} title="Minimize">
              <Minus size={16} />
            </HeaderButton>
            {isMaximized ? (
              <HeaderButton onClick={onRestore} title="Restore">
                <ArrowsIn size={16} />
              </HeaderButton>
            ) : (
              <HeaderButton onClick={onMaximize} title="Maximize">
                <ArrowsOut size={16} />
              </HeaderButton>
            )}
          </>
        )}
        <HeaderButton onClick={onClose} title="Close">
          <X size={16} />
        </HeaderButton>
      </div>
    </div>
  );
}