diff --git a/.claude/skills/orchestrate/SKILL.md b/.claude/skills/orchestrate/SKILL.md new file mode 100644 index 0000000000..58a7e45c79 --- /dev/null +++ b/.claude/skills/orchestrate/SKILL.md @@ -0,0 +1,709 @@ +--- +name: orchestrate +description: "Meta-agent supervisor that manages a fleet of Claude Code agents running in tmux windows. Auto-discovers spare worktrees, spawns agents, monitors state, kicks idle agents, approves safe confirmations, and recycles worktrees when done. TRIGGER when user asks to supervise agents, run parallel tasks, manage worktrees, check agent status, or orchestrate parallel work." +user-invocable: true +argument-hint: "any free text — e.g. 'start 3 agents on X Y Z', 'show status', 'add task: implement feature A', 'stop', 'how many are free?'" +metadata: + author: autogpt-team + version: "6.0.0" +--- + +# Orchestrate — Agent Fleet Supervisor + +One tmux session, N windows — each window is one agent working in its own worktree. Speak naturally; Claude maps your intent to the right scripts. + +## Scripts + +```bash +SKILLS_DIR=$(git rev-parse --show-toplevel)/.claude/skills/orchestrate/scripts +STATE_FILE=~/.claude/orchestrator-state.json +``` + +| Script | Purpose | +|---|---| +| `find-spare.sh [REPO_ROOT]` | List free worktrees — one `PATH BRANCH` per line | +| `spawn-agent.sh SESSION PATH SPARE NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]` | Create window + checkout branch + launch claude + send task. **Stdout: `SESSION:WIN` only** | +| `recycle-agent.sh WINDOW PATH SPARE_BRANCH` | Kill window + restore spare branch | +| `run-loop.sh` | **Mechanical babysitter** — idle restart + dialog approval + recycle on ORCHESTRATOR:DONE + supervisor health check + all-done notification | +| `verify-complete.sh WINDOW` | Verify PR is done: checkpoints ✓ + 0 unresolved threads + CI green + no fresh CHANGES_REQUESTED. Repo auto-derived from state file `.repo` or git remote. | +| `notify.sh MESSAGE` | Send notification via Discord webhook (env `DISCORD_WEBHOOK_URL` or state `.discord_webhook`), macOS notification center, and stdout | +| `capacity.sh [REPO_ROOT]` | Print available + in-use worktrees | +| `status.sh` | Print fleet status + live pane commands | +| `poll-cycle.sh` | One monitoring cycle — classifies panes, tracks checkpoints, returns JSON action array | +| `classify-pane.sh WINDOW` | Classify one pane state | + +## Supervision model + +``` +Orchestrating Claude (this Claude session — IS the supervisor) + └── Reads pane output, checks CI, intervenes with targeted guidance + run-loop.sh (separate tmux window, every 30s) + └── Mechanical only: idle restart, dialog approval, recycle on ORCHESTRATOR:DONE +``` + +**You (the orchestrating Claude)** are the supervisor. After spawning agents, stay in this conversation and actively monitor: poll each agent's pane every 2-3 minutes, check CI, nudge stalled agents, and verify completions. Do not spawn a separate supervisor Claude window — it loses context, is hard to observe, and compounds context compression problems. + +**run-loop.sh** is the mechanical layer — zero tokens, handles things that need no judgment: restart crashed agents, press Enter on dialogs, recycle completed worktrees (only after `verify-complete.sh` passes). + +## Checkpoint protocol + +Agents output checkpoints as they complete each required step: + +``` +CHECKPOINT: +``` + +Required steps are passed as args to `spawn-agent.sh` (e.g. `pr-address pr-test`). `run-loop.sh` will not recycle a window until all required checkpoints are found in the pane output. If `verify-complete.sh` fails, the agent is re-briefed automatically. + +## Worktree lifecycle + +```text +spare/N branch → spawn-agent.sh (--session-id UUID) → window + feat/branch + claude running + ↓ + CHECKPOINT: (as steps complete) + ↓ + ORCHESTRATOR:DONE + ↓ + verify-complete.sh: checkpoints ✓ + 0 threads + CI green + no fresh CHANGES_REQUESTED + ↓ + state → "done", notify, window KEPT OPEN + ↓ + user/orchestrator explicitly requests recycle + ↓ + recycle-agent.sh → spare/N (free again) +``` + +**Windows are never auto-killed.** The worktree stays on its branch, the session stays alive. The agent is done working but the window, git state, and Claude session are all preserved until you choose to recycle. + +**To resume a done or crashed session:** +```bash +# Resume by stored session ID (preferred — exact session, full context) +claude --resume SESSION_ID --permission-mode bypassPermissions + +# Or resume most recent session in that worktree directory +cd /path/to/worktree && claude --continue --permission-mode bypassPermissions +``` + +**To manually recycle when ready:** +```bash +bash ~/.claude/orchestrator/scripts/recycle-agent.sh SESSION:WIN WORKTREE_PATH spare/N +# Then update state: +jq --arg w "SESSION:WIN" '.agents |= map(if .window == $w then .state = "recycled" else . end)' \ + ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json +``` + +## State file (`~/.claude/orchestrator-state.json`) + +Never committed to git. You maintain this file directly using `jq` + atomic writes (`.tmp` → `mv`). + +```json +{ + "active": true, + "tmux_session": "autogpt1", + "idle_threshold_seconds": 300, + "loop_window": "autogpt1:5", + "repo": "Significant-Gravitas/AutoGPT", + "discord_webhook": "https://discord.com/api/webhooks/...", + "last_poll_at": 0, + "agents": [ + { + "window": "autogpt1:3", + "worktree": "AutoGPT6", + "worktree_path": "/path/to/AutoGPT6", + "spare_branch": "spare/6", + "branch": "feat/my-feature", + "objective": "Implement X and open a PR", + "pr_number": "12345", + "session_id": "550e8400-e29b-41d4-a716-446655440000", + "steps": ["pr-address", "pr-test"], + "checkpoints": ["pr-address"], + "state": "running", + "last_output_hash": "", + "last_seen_at": 0, + "spawned_at": 0, + "idle_since": 0, + "revision_count": 0, + "last_rebriefed_at": 0 + } + ] +} +``` + +Top-level optional fields: +- `repo` — GitHub `owner/repo` for CI/thread checks. Auto-derived from git remote if omitted. +- `discord_webhook` — Discord webhook URL for completion notifications. Also reads `DISCORD_WEBHOOK_URL` env var. + +Per-agent fields: +- `session_id` — UUID passed to `claude --session-id` at spawn; use with `claude --resume UUID` to restore exact session context after a crash or window close. +- `last_rebriefed_at` — Unix timestamp of last re-brief; enforces 5-min cooldown to prevent spam. + +Agent states: `running` | `idle` | `stuck` | `waiting_approval` | `complete` | `done` | `escalated` + +`done` means verified complete — window is still open, session still alive, worktree still on task branch. Not recycled yet. + +## Serial /pr-test rule + +`/pr-test` and `/pr-test --fix` run local Docker + integration tests that use shared ports, a shared database, and shared build caches. **Running two `/pr-test` jobs simultaneously will cause port conflicts and database corruption.** + +**Rule: only one `/pr-test` runs at a time. The orchestrator serializes them.** + +You (the orchestrating Claude) own the test queue: +1. Agents do `pr-review` and `pr-address` in parallel — that's safe (they only push code and reply to GitHub). +2. When a PR needs local testing, add it to your mental queue — don't give agents a `pr-test` step. +3. Run `/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER --fix` yourself, sequentially. +4. Feed results back to the relevant agent via `tmux send-keys`: + ```bash + tmux send-keys -t SESSION:WIN "Local tests for PR #N: . Fix any failures and push, then output ORCHESTRATOR:DONE." + sleep 0.3 + tmux send-keys -t SESSION:WIN Enter + ``` +5. Wait for CI to confirm green before marking the agent done. + +If multiple PRs need testing at the same time, pick the one furthest along (fewest pending CI checks) and test it first. Only start the next test after the previous one completes. + +## Session restore (tested and confirmed) + +Agent sessions are saved to disk. To restore a closed or crashed session: + +```bash +# If session_id is in state (preferred): +NEW_WIN=$(tmux new-window -t SESSION -n WORKTREE_NAME -P -F '#{window_index}') +tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --resume SESSION_ID --permission-mode bypassPermissions" Enter + +# If no session_id (use --continue for most recent session in that directory): +tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --continue --permission-mode bypassPermissions" Enter +``` + +`--continue` restores the full conversation history including all tool calls, file edits, and context. The agent resumes exactly where it left off. After restoring, update the window address in the state file: + +```bash +jq --arg old "SESSION:OLD_WIN" --arg new "SESSION:NEW_WIN" \ + '(.agents[] | select(.window == $old)).window = $new' \ + ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json +``` + +## Intent → action mapping + +Match the user's message to one of these intents: + +| The user says something like… | What to do | +|---|---| +| "status", "what's running", "show agents" | Run `status.sh` + `capacity.sh`, show output | +| "how many free", "capacity", "available worktrees" | Run `capacity.sh`, show output | +| "start N agents on X, Y, Z" or "run these tasks: …" | See **Spawning agents** below | +| "add task: …", "add one more agent for …" | See **Adding an agent** below | +| "stop", "shut down", "pause the fleet" | See **Stopping** below | +| "poll", "check now", "run a cycle" | Run `poll-cycle.sh`, process actions | +| "recycle window X", "free up autogpt3" | Run `recycle-agent.sh` directly | + +When the intent is ambiguous, show capacity first and ask what tasks to run. + +## Spawning agents + +### 1. Resolve tmux session + +```bash +tmux list-sessions -F "#{session_name}: #{session_windows} windows" 2>/dev/null +``` + +Use an existing session. **Never create a tmux session from within Claude** — it becomes a child of Claude's process and dies when the session ends. If no session exists, tell the user to run `tmux new-session -d -s autogpt1` in their terminal first, then re-invoke `/orchestrate`. + +### 2. Show available capacity + +```bash +bash $SKILLS_DIR/capacity.sh $(git rev-parse --show-toplevel) +``` + +### 3. Collect tasks from the user + +For each task, gather: +- **objective** — what to do (e.g. "implement feature X and open a PR") +- **branch name** — e.g. `feat/my-feature` (derive from objective if not given) +- **pr_number** — GitHub PR number if working on an existing PR (for verification) +- **steps** — required checkpoint names in order (e.g. `pr-address pr-test`) — derive from objective + +Ask for `idle_threshold_seconds` only if the user mentions it (default: 300). + +Never ask the user to specify a worktree — auto-assign from `find-spare.sh`. + +### 4. Spawn one agent per task + +```bash +# Get ordered list of spare worktrees +SPARE_LIST=$(bash $SKILLS_DIR/find-spare.sh $(git rev-parse --show-toplevel)) + +# For each task, take the next spare line: +WORKTREE_PATH=$(echo "$SPARE_LINE" | awk '{print $1}') +SPARE_BRANCH=$(echo "$SPARE_LINE" | awk '{print $2}') + +# With PR number and required steps: +WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE" "$PR_NUMBER" "pr-address" "pr-test") + +# Without PR (new work): +WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE") +``` + +Build an agent record and append it to the state file. If the state file doesn't exist yet, initialize it: + +```bash +# Derive repo from git remote (used by verify-complete.sh + supervisor) +REPO=$(git remote get-url origin 2>/dev/null | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "") + +jq -n \ + --arg session "$SESSION" \ + --arg repo "$REPO" \ + --argjson threshold 300 \ + '{active:true, tmux_session:$session, idle_threshold_seconds:$threshold, + repo:$repo, loop_window:null, supervisor_window:null, last_poll_at:0, agents:[]}' \ + > ~/.claude/orchestrator-state.json +``` + +Optionally add a Discord webhook for completion notifications: +```bash +jq --arg hook "$DISCORD_WEBHOOK_URL" '.discord_webhook = $hook' ~/.claude/orchestrator-state.json \ + > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json +``` + +`spawn-agent.sh` writes the initial agent record (window, worktree_path, branch, objective, state, etc.) to the state file automatically — **do not append the record again after calling it.** The record already exists and `pr_number`/`steps` are patched in by the script itself. + +### 5. Start the mechanical babysitter + +```bash +LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}') +LOOP_WINDOW="${SESSION}:${LOOP_WIN}" +tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter + +jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \ + > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json +``` + +### 6. Begin supervising directly in this conversation + +You are the supervisor. After spawning, immediately start your first poll loop (see **Supervisor duties** below) and continue every 2-3 minutes. Do NOT spawn a separate supervisor Claude window. + +## Adding an agent + +Find the next spare worktree, then spawn and append to state — same as steps 2–4 above but for a single task. If no spare worktrees are available, tell the user. + +## Supervisor duties (YOUR job, every 2-3 min in this conversation) + +You are the supervisor. Run this poll loop directly in your Claude session — not in a separate window. + +### Poll loop mechanism + +You are reactive — you only act when a tool completes or the user sends a message. To create a self-sustaining poll loop without user involvement: + +1. Start each poll with `run_in_background: true` + a sleep before the work: + ```bash + sleep 120 && tmux capture-pane -t autogpt1:0 -p -S -200 | tail -40 + # + similar for each active window + ``` +2. When the background job notifies you, read the pane output and take action. +3. Immediately schedule the next background poll — this keeps the loop alive. +4. Stop scheduling when all agents are done/escalated. + +**Never tell the user "I'll poll every 2-3 minutes"** — that does nothing without a trigger. Start the background job instead. + +### Each poll: what to check + +```bash +# 1. Read state +cat ~/.claude/orchestrator-state.json | jq '.agents[] | {window, worktree, branch, state, pr_number, checkpoints}' + +# 2. For each running/stuck/idle agent, capture pane +tmux capture-pane -t SESSION:WIN -p -S -200 | tail -60 +``` + +For each agent, decide: + +| What you see | Action | +|---|---| +| Spinner / tools running | Do nothing — agent is working | +| Idle `❯` prompt, no `ORCHESTRATOR:DONE` | Stalled — send specific nudge with objective from state | +| Stuck in error loop | Send targeted fix with exact error + solution | +| Waiting for input / question | Answer and unblock via `tmux send-keys` | +| CI red | `gh pr checks PR_NUMBER --repo REPO` → tell agent exactly what's failing | +| GitHub abuse rate limit error | Nudge: "Wait 60 seconds then continue posting replies with sleep 3 between each" | +| Context compacted / agent lost | Send recovery: `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="WIN")'` + `gh pr view PR_NUMBER --json title,body` | +| `ORCHESTRATOR:DONE` in output | Query GraphQL for actual unresolved count. If >0, re-brief. If 0, run `verify-complete.sh` | + +**Poll all windows from state, not from memory.** Before each poll, run: +```bash +jq -r '.agents[] | select(.state | test("running|idle|stuck|waiting_approval|pending_evaluation")) | .window' ~/.claude/orchestrator-state.json +``` +and capture every window listed. If you manually added a window outside spawn-agent.sh, ensure it's in the state file first. + +### RUNNING count includes waiting_approval agents + +The `RUNNING` count from run-loop.sh includes agents in `waiting_approval` state (they match the regex `running|stuck|waiting_approval|idle`). This means a fleet that is only `waiting_approval` still shows RUNNING > 0 in the log — it does **not** mean agents are actively working. + +When you see `RUNNING > 0` in the run-loop log but suspect agents are actually blocked, check state directly: +```bash +jq '.agents[] | {window, state, worktree}' ~/.claude/orchestrator-state.json +``` +A count of `running=1 waiting=1` in the log actually means one agent is waiting for approval — the orchestrator should check and approve, not wait. + +### State file staleness recovery + +The state file is written by scripts but can drift from reality when windows are closed, sessions expire, or the orchestrator restarts across conversations. + +**Signs of stale state:** +- `loop_window` points to a window that no longer exists in the tmux session +- An agent's `state` is `running` but tmux window is closed or shows a shell prompt (not claude) +- `last_seen_at` is hours old but state still says `running` + +**Recovery steps:** + +1. **Verify actual tmux windows:** +```bash +tmux list-windows -t SESSION -F '#{window_index}: #{window_name} (#{pane_current_command})' +``` + +2. **Cross-reference with state file:** +```bash +jq -r '.agents[] | "\(.window) \(.state) \(.worktree)"' ~/.claude/orchestrator-state.json +``` + +3. **Fix stale entries:** +```bash +# Agent window closed — mark idle so run-loop.sh will restart it +jq --arg w "SESSION:WIN" '(.agents[] | select(.window==$w)).state = "idle"' \ + ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json + +# loop_window gone — kill the stale reference, then restart run-loop.sh +jq '.loop_window = null' ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json +LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}') +LOOP_WINDOW="${SESSION}:${LOOP_WIN}" +tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter +jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \ + > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json +``` + +4. **After any state repair, re-run `status.sh` to confirm coherence before resuming supervision.** + +### Strict ORCHESTRATOR:DONE gate + +`verify-complete.sh` handles the main checks automatically (checkpoints, threads, CI green, spawned_at, and CHANGES_REQUESTED). Run it: + +**CHANGES_REQUESTED staleness rule**: a `CHANGES_REQUESTED` review only blocks if it was submitted *after* the latest commit. If the latest commit postdates the review, the review is considered stale (feedback already addressed) and does not block. This avoids false negatives when a bot reviewer hasn't re-reviewed after the agent's fixing commits. + +```bash +SKILLS_DIR=~/.claude/orchestrator/scripts +bash $SKILLS_DIR/verify-complete.sh SESSION:WIN +``` + +If it passes → run-loop.sh will recycle the window automatically. No manual action needed. +If it fails → re-brief the agent with the failure reason. Never manually mark state `done` to bypass this. + +### Re-brief a stalled agent + +**Before sending any nudge, verify the pane is at an idle ❯ prompt.** Sending text into a still-processing pane produces stuck `[Pasted text +N lines]` that the agent never sees. + +Check: +```bash +tmux capture-pane -t SESSION:WIN -p 2>/dev/null | tail -5 +``` +If the last line shows a spinner (✳✽✢✶·), `Running…`, or no `❯` — wait 10–15s and check again before sending. + +```bash +OBJ=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .objective' ~/.claude/orchestrator-state.json) +PR=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .pr_number' ~/.claude/orchestrator-state.json) +tmux send-keys -t SESSION:WIN "You appear stalled. Your objective: $OBJ. Check: gh pr view $PR --json title,body,headRefName to reorient." +sleep 0.3 +tmux send-keys -t SESSION:WIN Enter +``` + +If `image_path` is set on the agent record, include: "Re-read context at IMAGE_PATH with the Read tool." + +## Self-recovery protocol (agents) + +spawn-agent.sh automatically includes this instruction in every objective: + +> If your context compacts and you lose track of what to do, run: +> `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="SESSION:WIN")'` +> and `gh pr view PR_NUMBER --json title,body,headRefName` to reorient. +> Output each completed step as `CHECKPOINT:` on its own line. + +## Passing images and screenshots to agents + +`tmux send-keys` is text-only — you cannot paste a raw image into a pane. To give an agent visual context (screenshots, diagrams, mockups): + +1. **Save the image to a temp file** with a stable path: + ```bash + # If the user drags in a screenshot or you receive a file path: + IMAGE_PATH="/tmp/orchestrator-context-$(date +%s).png" + cp "$USER_PROVIDED_PATH" "$IMAGE_PATH" + ``` + +2. **Reference the path in the objective string**: + ```bash + OBJECTIVE="Implement the layout shown in /tmp/orchestrator-context-1234567890.png. Read that image first with the Read tool to understand the design." + ``` + +3. The agent uses its `Read` tool to view the image at startup — Claude Code agents are multimodal and can read image files directly. + +**Rule**: always use `/tmp/orchestrator-context-.png` as the naming convention so the supervisor knows what to look for if it needs to re-brief an agent with the same image. + +--- + +## Orchestrator final evaluation (YOU decide, not the script) + +`verify-complete.sh` is a gate — it blocks premature marking. But it cannot tell you if the work is actually good. That is YOUR job. + +When run-loop marks an agent `pending_evaluation` and you're notified, do all of these before marking done: + +### 1. Run /pr-test (required, serialized, use TodoWrite to queue) + +`/pr-test` is the only reliable confirmation that the objective is actually met. Run it yourself, not the agent. + +**When multiple PRs reach `pending_evaluation` at the same time, use TodoWrite to queue them:** +``` +- [ ] /pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/NNNN — +- [ ] /pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/MMMM — +``` +Run one at a time. Check off as you go. + +``` +/pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/PR_NUMBER +``` + +**/pr-test can be lazy** — if it gives vague output, re-run with full context: + +``` +/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER +Context: This PR implements . Key files: . +Please verify: . +``` + +Only one `/pr-test` at a time — they share ports and DB. + +### /pr-test result evaluation + +**PARTIAL on any headline feature scenario is an immediate blocker.** Do not approve, do not mark done, do not let the agent output `ORCHESTRATOR:DONE`. + +| `/pr-test` result | Action | +|---|---| +| All headline scenarios **PASS** | Proceed to evaluation step 2 | +| Any headline scenario **PARTIAL** | Re-brief the agent immediately — see below | +| Any headline scenario **FAIL** | Re-brief the agent immediately | + +**What PARTIAL means**: the feature is only partly working. Example: the Apply button never appeared, or the AI returned no action blocks. The agent addressed part of the objective but not all of it. + +**When any headline scenario is PARTIAL or FAIL:** + +1. Do NOT mark the agent done or accept `ORCHESTRATOR:DONE` +2. Re-brief the agent with the specific scenario that failed and what was missing: + ```bash + tmux send-keys -t SESSION:WIN "PARTIAL result on /pr-test — S5 (Apply button) never appeared. The AI must output JSON action blocks for the Apply button to render. Fix this before re-running /pr-test." + sleep 0.3 + tmux send-keys -t SESSION:WIN Enter + ``` +3. Set state back to `running`: + ```bash + jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "running"' \ + ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json + ``` +4. Wait for new `ORCHESTRATOR:DONE`, then re-run `/pr-test` from scratch + +**Rule: only ALL-PASS qualifies for approval.** A mix of PASS + PARTIAL is a failure. + +> **Why this matters**: A PR was once wrongly approved with S5 PARTIAL — the AI never output JSON action blocks so the Apply button never appeared. The fix was already in the agent's reach but slipped through because PARTIAL was not treated as blocking. + +### 2. Do your own evaluation + +1. **Read the PR diff and objective** — does the code actually implement what was asked? Is anything obviously missing or half-done? +2. **Read the resolved threads** — were comments addressed with real fixes, or just dismissed/resolved without changes? +3. **Check CI run names** — any suspicious retries that shouldn't have passed? +4. **Check the PR description** — title, summary, test plan complete? + +### 3. Decide + +- `/pr-test` all scenarios PASS + evaluation looks good → mark `done` in state, tell the user the PR is ready, ask if window should be closed +- `/pr-test` any scenario PARTIAL or FAIL → re-brief the agent with the specific failing scenario, set state back to `running` (see `/pr-test result evaluation` above) +- Evaluation finds gaps even with all PASS → re-brief the agent with specific gaps, set state back to `running` + +**Never mark done based purely on script output.** You hold the full objective context; the script does not. + +```bash +# Mark done after your positive evaluation: +jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "done"' \ + ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json +``` + +## When to stop the fleet + +Stop the fleet (`active = false`) when **all** of the following are true: + +| Check | How to verify | +|---|---| +| All agents are `done` or `escalated` | `jq '[.agents[] | select(.state | test("running\|stuck\|idle\|waiting_approval"))] | length' ~/.claude/orchestrator-state.json` == 0 | +| All PRs have 0 unresolved review threads | GraphQL `isResolved` check per PR | +| All PRs have green CI **on a run triggered after the agent's last push** | `gh run list --branch BRANCH --limit 1` timestamp > `spawned_at` in state | +| No fresh CHANGES_REQUESTED (after latest commit) | `verify-complete.sh` checks this — stale pre-commit reviews are ignored | +| No agents are `escalated` without human review | If any are escalated, surface to user first | + +**Do NOT stop just because agents output `ORCHESTRATOR:DONE`.** That is a signal to verify, not a signal to stop. + +**Do stop** if the user explicitly says "stop", "shut down", or "kill everything", even with agents still running. + +```bash +# Graceful stop +jq '.active = false' ~/.claude/orchestrator-state.json > /tmp/orch.tmp \ + && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json + +LOOP_WINDOW=$(jq -r '.loop_window // ""' ~/.claude/orchestrator-state.json) +[ -n "$LOOP_WINDOW" ] && tmux kill-window -t "$LOOP_WINDOW" 2>/dev/null || true +``` + +Does **not** recycle running worktrees — agents may still be mid-task. Run `capacity.sh` to see what's still in progress. + +## tmux send-keys pattern + +**Always split long messages into text + Enter as two separate calls with a sleep between them.** If sent as one call (`"text" Enter`), Enter can fire before the full string is buffered into Claude's input — leaving the message stuck as `[Pasted text +N lines]` unsent. + +```bash +# CORRECT — text then Enter separately +tmux send-keys -t "$WINDOW" "your long message here" +sleep 0.3 +tmux send-keys -t "$WINDOW" Enter + +# WRONG — Enter may fire before text is buffered +tmux send-keys -t "$WINDOW" "your long message here" Enter +``` + +Short single-character sends (`y`, `Down`, empty Enter for dialog approval) are safe to combine since they have no buffering lag. + +--- + +## Protected worktrees + +Some worktrees must **never** be used as spare worktrees for agent tasks because they host files critical to the orchestrator itself: + +| Worktree | Protected branch | Why | +|---|---|---| +| `AutoGPT1` | `dx/orchestrate-skill` | Hosts the orchestrate skill scripts. `recycle-agent.sh` would check out `spare/1`, wiping `.claude/skills/` and breaking all subsequent `spawn-agent.sh` calls. | + +**Rule**: when selecting spare worktrees via `find-spare.sh`, skip any worktree whose CURRENT branch matches a protected branch. If you accidentally spawn an agent in a protected worktree, do not let `recycle-agent.sh` run on it — manually restore the branch after the agent finishes. + +When `dx/orchestrate-skill` is merged into `dev`, `AutoGPT1` becomes a normal spare again. + +--- + +## Thread resolution integrity (critical) + +**Agents MUST NOT resolve review threads via GraphQL unless a real code fix has been committed and pushed first.** + +This is the most common failure mode: agents call `resolveReviewThread` to make unresolved counts drop without actually fixing anything. This produces a false "done" signal that gets past verify-complete.sh. + +**The only valid resolution sequence:** +1. Read the thread and understand what it's asking +2. Make the actual code change +3. `git commit` and `git push` +4. Reply to the thread with the commit SHA (e.g. "Fixed in `abc1234`") +5. THEN call `resolveReviewThread` + +**The supervisor must verify actual thread counts via GraphQL** — never trust an agent's claim of "0 unresolved." After any agent's ORCHESTRATOR:DONE, always run: + +```bash +# Step 1: get total count +TOTAL=$(gh api graphql -f query='{ repository(owner: "OWNER", name: "REPO") { pullRequest(number: PR) { reviewThreads { totalCount } } } }' \ + | jq '.data.repository.pullRequest.reviewThreads.totalCount') +echo "Total threads: $TOTAL" + +# Step 2: paginate all pages and count unresolved +CURSOR=""; UNRESOLVED=0 +while true; do + AFTER=${CURSOR:+", after: \"$CURSOR\""} + PAGE=$(gh api graphql -f query="{ repository(owner: \"OWNER\", name: \"REPO\") { pullRequest(number: PR) { reviewThreads(first: 100${AFTER}) { pageInfo { hasNextPage endCursor } nodes { isResolved } } } } }") + UNRESOLVED=$(( UNRESOLVED + $(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved==false)] | length') )) + HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage') + CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor') + [ "$HAS_NEXT" = "false" ] && break +done +echo "Unresolved: $UNRESOLVED" +``` + +If unresolved > 0, the agent is NOT done — re-brief with the actual count and the rule. + +**Include this in every agent objective:** +> IMPORTANT: Do NOT resolve any review thread via GraphQL unless the code fix is committed and pushed first. Fix the code → commit → push → reply with SHA → then resolve. Never resolve without a real commit. "Accepted" or "Acknowledged" replies are NOT resolutions — only real commits qualify. + +### Detecting fake resolutions + +When an agent claims "0 unresolved threads", query GitHub GraphQL yourself and also inspect how each thread was resolved. A resolved thread whose last comment is `"Acknowledged"`, `"Same as above"`, `"Accepted trade-off"`, or `"Deferred"` — with no commit SHA — is a fake resolution. + +To spot these, paginate all pages and collect resolved threads with missing SHA links: +```bash +# Paginate all pages — first:100 misses threads beyond page 1 on large PRs +CURSOR=""; FAKE_RESOLUTIONS="[]" +while true; do + AFTER=${CURSOR:+", after: \"$CURSOR\""} + PAGE=$(gh api graphql -f query=" + { + repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") { + pullRequest(number: PR_NUMBER) { + reviewThreads(first: 100${AFTER}) { + pageInfo { hasNextPage endCursor } + nodes { + isResolved + comments(last: 1) { + nodes { body author { login } } + } + } + } + } + } + }") + PAGE_FAKES=$(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] + | select(.isResolved == true) + | {body: .comments.nodes[0].body[:120], author: .comments.nodes[0].author.login} + | select(.body | test("Fixed in|Removed in|Addressed in") | not)]') + FAKE_RESOLUTIONS=$(echo "$FAKE_RESOLUTIONS $PAGE_FAKES" | jq -s 'add') + HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage') + CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor') + [ "$HAS_NEXT" = "false" ] && break +done +echo "$FAKE_RESOLUTIONS" +``` +Any resolved thread whose last comment does NOT contain `"Fixed in"`, `"Removed in"`, or `"Addressed in"` (with a commit link) should be investigated — either the agent falsely resolved it, or it was a genuine false positive that needs explanation. + +## GitHub abuse rate limits + +Two distinct rate limits exist with different recovery times: + +| Error | HTTP status | Cause | Recovery | +|---|---|---|---| +| `{"code":"abuse"}` in body | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **2–3 minutes**. 60s is often not enough. | +| `API rate limit exceeded` | 429 | Primary rate limit — too many read calls per hour | Wait until `X-RateLimit-Reset` timestamp | + +**Prevention:** Agents must add `sleep 3` between individual thread reply API calls. For >20 unresolved threads, increase to `sleep 5`. + +If you see a 403 `abuse` error from an agent's pane: +1. Nudge the agent: `"You hit a GitHub secondary rate limit (403). Stop all API writes. Wait 2 minutes, then resume with sleep 3 between each thread reply."` +2. Do NOT nudge again during the 2-minute wait — a second nudge restarts the clock. + +Add this to agent briefings when there are >20 unresolved threads: +> Post replies with `sleep 3` between each reply. If you hit a 403 abuse error, wait 2 minutes (not 60s — secondary limits take longer to clear) then continue. + +## Key rules + +1. **Scripts do all the heavy lifting** — don't reimplement their logic inline in this file +2. **Never ask the user to pick a worktree** — auto-assign from `find-spare.sh` output +3. **Never restart a running agent** — only restart on `idle` kicks (foreground is a shell) +4. **Auto-dismiss settings dialogs** — if "Enter to confirm" appears, send Down+Enter +5. **Always `--permission-mode bypassPermissions`** on every spawn +6. **Escalate after 3 kicks** — mark `escalated`, surface to user +7. **Atomic state writes** — always write to `.tmp` then `mv` +8. **Never approve destructive commands** outside the worktree scope — when in doubt, escalate +9. **Never recycle without verification** — `verify-complete.sh` must pass before recycling +10. **No TASK.md files** — commit risk; use state file + `gh pr view` for agent context persistence +11. **Re-brief stalled agents** — read objective from state file + `gh pr view`, send via tmux +12. **ORCHESTRATOR:DONE is a signal to verify, not to accept** — always run `verify-complete.sh` and check CI run timestamp before recycling +13. **Protected worktrees** — never use the worktree hosting the skill scripts as a spare +14. **Images via file path** — save screenshots to `/tmp/orchestrator-context-.png`, pass path in objective; agents read with the `Read` tool +15. **Split send-keys** — always separate text and Enter with `sleep 0.3` between calls for long strings +16. **Poll ALL windows from state file** — never hardcode window count. Derive active windows dynamically: `jq -r '.agents[] | select(.state | test("running|idle|stuck")) | .window' ~/.claude/orchestrator-state.json`. If you added a window mid-session outside spawn-agent.sh, add it to the state file immediately. +20. **Orchestrator handles its own approvals** — when spawning a subagent to make edits (SKILL.md, scripts, config), review the diff yourself and approve/reject without surfacing it to the user. The user should never have to open a file to check the orchestrator's work. Use the Agent tool with `subagent_type: general-purpose` for drafting, then verify the result yourself before considering the task done. +17. **Update state file on re-task** — whenever an agent is re-tasked mid-session (objective changes, new PR assigned), update the state file record immediately so objectives stay accurate for re-briefing after compaction. +18. **No GraphQL resolveReviewThread without a commit** — see Thread resolution integrity above. This is rule #1 for pr-address work. +19. **Verify thread counts yourself** — after any agent claims "0 unresolved threads", query GitHub GraphQL directly before accepting. Never trust the agent's self-report. diff --git a/.claude/skills/orchestrate/scripts/capacity.sh b/.claude/skills/orchestrate/scripts/capacity.sh new file mode 100755 index 0000000000..1bbf376297 --- /dev/null +++ b/.claude/skills/orchestrate/scripts/capacity.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# capacity.sh — show fleet capacity: available spare worktrees + in-use agents +# +# Usage: capacity.sh [REPO_ROOT] +# REPO_ROOT defaults to the root worktree of the current git repo. +# +# Reads: ~/.claude/orchestrator-state.json (skipped if missing or corrupt) + +set -euo pipefail + +SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}" +REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}" + +echo "=== Available (spare) worktrees ===" +if [ -n "$REPO_ROOT" ]; then + SPARE=$("$SCRIPTS_DIR/find-spare.sh" "$REPO_ROOT" 2>/dev/null || echo "") +else + SPARE=$("$SCRIPTS_DIR/find-spare.sh" 2>/dev/null || echo "") +fi + +if [ -z "$SPARE" ]; then + echo " (none)" +else + while IFS= read -r line; do + [ -z "$line" ] && continue + echo " ✓ $line" + done <<< "$SPARE" +fi + +echo "" +echo "=== In-use worktrees ===" +if [ -f "$STATE_FILE" ] && jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then + IN_USE=$(jq -r '.agents[] | select(.state != "done") | " [\(.state)] \(.worktree_path) → \(.branch)"' \ + "$STATE_FILE" 2>/dev/null || echo "") + if [ -n "$IN_USE" ]; then + echo "$IN_USE" + else + echo " (none)" + fi +else + echo " (no active state file)" +fi diff --git a/.claude/skills/orchestrate/scripts/classify-pane.sh b/.claude/skills/orchestrate/scripts/classify-pane.sh new file mode 100755 index 0000000000..57504c72ce --- /dev/null +++ b/.claude/skills/orchestrate/scripts/classify-pane.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash +# classify-pane.sh — Classify the current state of a tmux pane +# +# Usage: classify-pane.sh +# tmux-target: e.g. "work:0", "work:1.0" +# +# Output (stdout): JSON object: +# { "state": "running|idle|waiting_approval|complete", "reason": "...", "pane_cmd": "..." } +# +# Exit codes: 0=ok, 1=error (invalid target or tmux window not found) + +set -euo pipefail + +TARGET="${1:-}" + +if [ -z "$TARGET" ]; then + echo '{"state":"error","reason":"no target provided","pane_cmd":""}' + exit 1 +fi + +# Validate tmux target format: session:window or session:window.pane +if ! [[ "$TARGET" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then + echo '{"state":"error","reason":"invalid tmux target format","pane_cmd":""}' + exit 1 +fi + +# Check session exists (use %%:* to extract session name from session:window) +if ! tmux list-windows -t "${TARGET%%:*}" &>/dev/null 2>&1; then + echo '{"state":"error","reason":"tmux target not found","pane_cmd":""}' + exit 1 +fi + +# Get the current foreground command in the pane +PANE_CMD=$(tmux display-message -t "$TARGET" -p '#{pane_current_command}' 2>/dev/null || echo "unknown") + +# Capture and strip ANSI codes (use perl for cross-platform compatibility — BSD sed lacks \x1b support) +RAW=$(tmux capture-pane -t "$TARGET" -p -S -50 2>/dev/null || echo "") +CLEAN=$(echo "$RAW" | perl -pe 's/\x1b\[[0-9;]*[a-zA-Z]//g; s/\x1b\(B//g; s/\x1b\[\?[0-9]*[hl]//g; s/\r//g' \ + | grep -v '^[[:space:]]*$' || true) + +# --- Check: explicit completion marker --- +# Must be on its own line (not buried in the objective text sent at spawn time). +if echo "$CLEAN" | grep -qE "^[[:space:]]*ORCHESTRATOR:DONE[[:space:]]*$"; then + jq -n --arg cmd "$PANE_CMD" '{"state":"complete","reason":"ORCHESTRATOR:DONE marker found","pane_cmd":$cmd}' + exit 0 +fi + +# --- Check: Claude Code approval prompt patterns --- +LAST_40=$(echo "$CLEAN" | tail -40) +APPROVAL_PATTERNS=( + "Do you want to proceed" + "Do you want to make this" + "\\[y/n\\]" + "\\[Y/n\\]" + "\\[n/Y\\]" + "Proceed\\?" + "Allow this command" + "Run bash command" + "Allow bash" + "Would you like" + "Press enter to continue" + "Esc to cancel" +) +for pattern in "${APPROVAL_PATTERNS[@]}"; do + if echo "$LAST_40" | grep -qiE "$pattern"; then + jq -n --arg pattern "$pattern" --arg cmd "$PANE_CMD" \ + '{"state":"waiting_approval","reason":"approval pattern: \($pattern)","pane_cmd":$cmd}' + exit 0 + fi +done + +# --- Check: shell prompt (claude has exited) --- +# If the foreground process is a shell (not claude/node), the agent has exited +case "$PANE_CMD" in + zsh|bash|fish|sh|dash|tcsh|ksh) + jq -n --arg cmd "$PANE_CMD" \ + '{"state":"idle","reason":"agent exited — shell prompt active","pane_cmd":$cmd}' + exit 0 + ;; +esac + +# Agent is still running (claude/node/python is the foreground process) +jq -n --arg cmd "$PANE_CMD" \ + '{"state":"running","reason":"foreground process: \($cmd)","pane_cmd":$cmd}' +exit 0 diff --git a/.claude/skills/orchestrate/scripts/find-spare.sh b/.claude/skills/orchestrate/scripts/find-spare.sh new file mode 100755 index 0000000000..e374a41c9b --- /dev/null +++ b/.claude/skills/orchestrate/scripts/find-spare.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# find-spare.sh — list worktrees on spare/N branches (free to use) +# +# Usage: find-spare.sh [REPO_ROOT] +# REPO_ROOT defaults to the root worktree containing the current git repo. +# +# Output (stdout): one line per available worktree: "PATH BRANCH" +# e.g.: /Users/me/Code/AutoGPT3 spare/3 + +set -euo pipefail + +REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}" +if [ -z "$REPO_ROOT" ]; then + echo "Error: not inside a git repo and no REPO_ROOT provided" >&2 + exit 1 +fi + +git -C "$REPO_ROOT" worktree list --porcelain \ + | awk ' + /^worktree / { path = substr($0, 10) } + /^branch / { branch = substr($0, 8); print path " " branch } + ' \ + | { grep -E " refs/heads/spare/[0-9]+$" || true; } \ + | sed 's|refs/heads/||' diff --git a/.claude/skills/orchestrate/scripts/notify.sh b/.claude/skills/orchestrate/scripts/notify.sh new file mode 100755 index 0000000000..ace46cc152 --- /dev/null +++ b/.claude/skills/orchestrate/scripts/notify.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# notify.sh — send a fleet notification message +# +# Delivery order (first available wins): +# 1. Discord webhook — DISCORD_WEBHOOK_URL env var OR state file .discord_webhook +# 2. macOS notification center — osascript (silent fail if unavailable) +# 3. Stdout only +# +# Usage: notify.sh MESSAGE +# Exit: always 0 (notification failure must not abort the caller) + +MESSAGE="${1:-}" +[ -z "$MESSAGE" ] && exit 0 + +STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}" + +# --- Resolve Discord webhook --- +WEBHOOK="${DISCORD_WEBHOOK_URL:-}" +if [ -z "$WEBHOOK" ] && [ -f "$STATE_FILE" ]; then + WEBHOOK=$(jq -r '.discord_webhook // ""' "$STATE_FILE" 2>/dev/null || echo "") +fi + +# --- Discord delivery --- +if [ -n "$WEBHOOK" ]; then + PAYLOAD=$(jq -n --arg msg "$MESSAGE" '{"content": $msg}') + curl -s -X POST "$WEBHOOK" \ + -H "Content-Type: application/json" \ + -d "$PAYLOAD" > /dev/null 2>&1 || true +fi + +# --- macOS notification center (silent if not macOS or osascript missing) --- +if command -v osascript &>/dev/null 2>&1; then + # Escape single quotes for AppleScript + SAFE_MSG=$(echo "$MESSAGE" | sed "s/'/\\\\'/g") + osascript -e "display notification \"${SAFE_MSG}\" with title \"Orchestrator\"" 2>/dev/null || true +fi + +# Always print to stdout so run-loop.sh logs it +echo "$MESSAGE" +exit 0 diff --git a/.claude/skills/orchestrate/scripts/poll-cycle.sh b/.claude/skills/orchestrate/scripts/poll-cycle.sh new file mode 100755 index 0000000000..dafd307bf3 --- /dev/null +++ b/.claude/skills/orchestrate/scripts/poll-cycle.sh @@ -0,0 +1,257 @@ +#!/usr/bin/env bash +# poll-cycle.sh — Single orchestrator poll cycle +# +# Reads ~/.claude/orchestrator-state.json, classifies each agent, updates state, +# and outputs a JSON array of actions for Claude to take. +# +# Usage: poll-cycle.sh +# Output (stdout): JSON array of action objects +# [{ "window": "work:0", "action": "kick|approve|none", "state": "...", +# "worktree": "...", "objective": "...", "reason": "..." }] +# +# The state file is updated in-place (atomic write via .tmp). + +set -euo pipefail + +STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}" +SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CLASSIFY="$SCRIPTS_DIR/classify-pane.sh" + +# Cross-platform md5: always outputs just the hex digest +md5_hash() { + if command -v md5sum &>/dev/null; then + md5sum | awk '{print $1}' + else + md5 | awk '{print $NF}' + fi +} + +# Clean up temp file on any exit (avoids stale .tmp if jq write fails) +trap 'rm -f "${STATE_FILE}.tmp"' EXIT + +# Ensure state file exists +if [ ! -f "$STATE_FILE" ]; then + echo '{"active":false,"agents":[]}' > "$STATE_FILE" +fi + +# Validate JSON upfront before any jq reads that run under set -e. +# A truncated/corrupt file (e.g. from a SIGKILL mid-write) would otherwise +# abort the script at the ACTIVE read below without emitting any JSON output. +if ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then + echo "State file parse error — check $STATE_FILE" >&2 + echo "[]" + exit 0 +fi + +ACTIVE=$(jq -r '.active // false' "$STATE_FILE") +if [ "$ACTIVE" != "true" ]; then + echo "[]" + exit 0 +fi + +NOW=$(date +%s) +IDLE_THRESHOLD=$(jq -r '.idle_threshold_seconds // 300' "$STATE_FILE") + +ACTIONS="[]" +UPDATED_AGENTS="[]" + +# Read agents as newline-delimited JSON objects. +# jq exits non-zero when .agents[] has no matches on an empty array, which is valid — +# so we suppress that exit code and separately validate the file is well-formed JSON. +if ! AGENTS_JSON=$(jq -e -c '.agents // empty | .[]' "$STATE_FILE" 2>/dev/null); then + if ! jq -e '.' "$STATE_FILE" > /dev/null 2>&1; then + echo "State file parse error — check $STATE_FILE" >&2 + fi + echo "[]" + exit 0 +fi + +if [ -z "$AGENTS_JSON" ]; then + echo "[]" + exit 0 +fi + +while IFS= read -r agent; do + [ -z "$agent" ] && continue + + # Use // "" defaults so a single malformed field doesn't abort the whole cycle + WINDOW=$(echo "$agent" | jq -r '.window // ""') + WORKTREE=$(echo "$agent" | jq -r '.worktree // ""') + OBJECTIVE=$(echo "$agent"| jq -r '.objective // ""') + STATE=$(echo "$agent" | jq -r '.state // "running"') + LAST_HASH=$(echo "$agent"| jq -r '.last_output_hash // ""') + IDLE_SINCE=$(echo "$agent"| jq -r '.idle_since // 0') + REVISION_COUNT=$(echo "$agent"| jq -r '.revision_count // 0') + + # Validate window format to prevent tmux target injection. + # Allow session:window (numeric or named) and session:window.pane + if ! [[ "$WINDOW" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then + echo "Skipping agent with invalid window value: $WINDOW" >&2 + UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]') + continue + fi + + # Pass-through terminal-state agents + if [[ "$STATE" == "done" || "$STATE" == "escalated" || "$STATE" == "complete" || "$STATE" == "pending_evaluation" ]]; then + UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]') + continue + fi + + # Classify pane. + # classify-pane.sh always emits JSON before exit (even on error), so using + # "|| echo '...'" would concatenate two JSON objects when it exits non-zero. + # Use "|| true" inside the substitution so set -euo pipefail does not abort + # the poll cycle when classify exits with a non-zero status code. + CLASSIFICATION=$("$CLASSIFY" "$WINDOW" 2>/dev/null || true) + [ -z "$CLASSIFICATION" ] && CLASSIFICATION='{"state":"error","reason":"classify failed","pane_cmd":"unknown"}' + + PANE_STATE=$(echo "$CLASSIFICATION" | jq -r '.state') + PANE_REASON=$(echo "$CLASSIFICATION" | jq -r '.reason') + + # Capture full pane output once — used for hash (stuck detection) and checkpoint parsing. + # Use -S -500 to get the last ~500 lines of scrollback so checkpoints aren't missed. + RAW=$(tmux capture-pane -t "$WINDOW" -p -S -500 2>/dev/null || echo "") + + # --- Checkpoint tracking --- + # Parse any "CHECKPOINT:" lines the agent has output and merge into state file. + # The agent writes these as it completes each required step so verify-complete.sh can gate recycling. + EXISTING_CPS=$(echo "$agent" | jq -c '.checkpoints // []') + NEW_CHECKPOINTS_JSON="$EXISTING_CPS" + if [ -n "$RAW" ]; then + FOUND_CPS=$(echo "$RAW" \ + | grep -oE "CHECKPOINT:[a-zA-Z0-9_-]+" \ + | sed 's/CHECKPOINT://' \ + | sort -u \ + | jq -R . | jq -s . 2>/dev/null || echo "[]") + NEW_CHECKPOINTS_JSON=$(jq -n \ + --argjson existing "$EXISTING_CPS" \ + --argjson found "$FOUND_CPS" \ + '($existing + $found) | unique' 2>/dev/null || echo "$EXISTING_CPS") + fi + + # Compute content hash for stuck-detection (only for running agents) + CURRENT_HASH="" + if [[ "$PANE_STATE" == "running" ]] && [ -n "$RAW" ]; then + CURRENT_HASH=$(echo "$RAW" | tail -20 | md5_hash) + fi + + NEW_STATE="$STATE" + NEW_IDLE_SINCE="$IDLE_SINCE" + NEW_REVISION_COUNT="$REVISION_COUNT" + ACTION="none" + REASON="$PANE_REASON" + + case "$PANE_STATE" in + complete) + # Agent output ORCHESTRATOR:DONE — mark pending_evaluation so orchestrator handles it. + # run-loop does NOT verify or notify; orchestrator's background poll picks this up. + NEW_STATE="pending_evaluation" + ACTION="complete" # run-loop logs it but takes no action + ;; + waiting_approval) + NEW_STATE="waiting_approval" + ACTION="approve" + ;; + idle) + # Agent process has exited — needs restart + NEW_STATE="idle" + ACTION="kick" + REASON="agent exited (shell is foreground)" + NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 )) + NEW_IDLE_SINCE=$NOW + if [ "$NEW_REVISION_COUNT" -ge 3 ]; then + NEW_STATE="escalated" + ACTION="none" + REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention" + fi + ;; + running) + # Clear idle_since only when transitioning from idle (agent was kicked and + # restarted). Do NOT reset for stuck — idle_since must persist across polls + # so STUCK_DURATION can accumulate and trigger escalation. + # Also update the local IDLE_SINCE so the hash-stability check below uses + # the reset value on this same poll, not the stale kick timestamp. + if [[ "$STATE" == "idle" ]]; then + NEW_IDLE_SINCE=0 + IDLE_SINCE=0 + fi + # Check if hash has been stable (agent may be stuck mid-task) + if [ -n "$CURRENT_HASH" ] && [ "$CURRENT_HASH" = "$LAST_HASH" ] && [ "$LAST_HASH" != "" ]; then + if [ "$IDLE_SINCE" = "0" ] || [ "$IDLE_SINCE" = "null" ]; then + NEW_IDLE_SINCE=$NOW + else + STUCK_DURATION=$(( NOW - IDLE_SINCE )) + if [ "$STUCK_DURATION" -gt "$IDLE_THRESHOLD" ]; then + NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 )) + NEW_IDLE_SINCE=$NOW + if [ "$NEW_REVISION_COUNT" -ge 3 ]; then + NEW_STATE="escalated" + ACTION="none" + REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention" + else + NEW_STATE="stuck" + ACTION="kick" + REASON="output unchanged for ${STUCK_DURATION}s (threshold: ${IDLE_THRESHOLD}s)" + fi + fi + fi + else + # Only reset the idle timer when we have a valid hash comparison (pane + # capture succeeded). If CURRENT_HASH is empty (tmux capture-pane failed), + # preserve existing timers so stuck detection is not inadvertently reset. + if [ -n "$CURRENT_HASH" ]; then + NEW_STATE="running" + NEW_IDLE_SINCE=0 + fi + fi + ;; + error) + REASON="classify error: $PANE_REASON" + ;; + esac + + # Build updated agent record (ensure idle_since and revision_count are numeric) + # Use || true on each jq call so a malformed field skips this agent rather than + # aborting the entire poll cycle under set -e. + UPDATED_AGENT=$(echo "$agent" | jq \ + --arg state "$NEW_STATE" \ + --arg hash "$CURRENT_HASH" \ + --argjson now "$NOW" \ + --arg idle_since "$NEW_IDLE_SINCE" \ + --arg revision_count "$NEW_REVISION_COUNT" \ + --argjson checkpoints "$NEW_CHECKPOINTS_JSON" \ + '.state = $state + | .last_output_hash = (if $hash == "" then .last_output_hash else $hash end) + | .last_seen_at = $now + | .idle_since = ($idle_since | tonumber) + | .revision_count = ($revision_count | tonumber) + | .checkpoints = $checkpoints' 2>/dev/null) || { + echo "Warning: failed to build updated agent for window $WINDOW — keeping original" >&2 + UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]') + continue + } + + UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$UPDATED_AGENT" '. + [$a]') + + # Add action if needed + if [ "$ACTION" != "none" ]; then + ACTION_OBJ=$(jq -n \ + --arg window "$WINDOW" \ + --arg action "$ACTION" \ + --arg state "$NEW_STATE" \ + --arg worktree "$WORKTREE" \ + --arg objective "$OBJECTIVE" \ + --arg reason "$REASON" \ + '{window:$window, action:$action, state:$state, worktree:$worktree, objective:$objective, reason:$reason}') + ACTIONS=$(echo "$ACTIONS" | jq --argjson a "$ACTION_OBJ" '. + [$a]') + fi + +done <<< "$AGENTS_JSON" + +# Atomic state file update +jq --argjson agents "$UPDATED_AGENTS" \ + --argjson now "$NOW" \ + '.agents = $agents | .last_poll_at = $now' \ + "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE" + +echo "$ACTIONS" diff --git a/.claude/skills/orchestrate/scripts/recycle-agent.sh b/.claude/skills/orchestrate/scripts/recycle-agent.sh new file mode 100755 index 0000000000..6d5e2fdc8f --- /dev/null +++ b/.claude/skills/orchestrate/scripts/recycle-agent.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash +# recycle-agent.sh — kill a tmux window and restore the worktree to its spare branch +# +# Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH +# WINDOW — tmux target, e.g. autogpt1:3 +# WORKTREE_PATH — absolute path to the git worktree +# SPARE_BRANCH — branch to restore, e.g. spare/6 +# +# Stdout: one status line + +set -euo pipefail + +if [ $# -lt 3 ]; then + echo "Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH" >&2 + exit 1 +fi + +WINDOW="$1" +WORKTREE_PATH="$2" +SPARE_BRANCH="$3" + +# Kill the tmux window (ignore error — may already be gone) +tmux kill-window -t "$WINDOW" 2>/dev/null || true + +# Restore to spare branch: abort any in-progress operation, then clean +git -C "$WORKTREE_PATH" rebase --abort 2>/dev/null || true +git -C "$WORKTREE_PATH" merge --abort 2>/dev/null || true +git -C "$WORKTREE_PATH" reset --hard HEAD 2>/dev/null +git -C "$WORKTREE_PATH" clean -fd 2>/dev/null +git -C "$WORKTREE_PATH" checkout "$SPARE_BRANCH" + +echo "Recycled: $(basename "$WORKTREE_PATH") → $SPARE_BRANCH (window $WINDOW closed)" diff --git a/.claude/skills/orchestrate/scripts/run-loop.sh b/.claude/skills/orchestrate/scripts/run-loop.sh new file mode 100755 index 0000000000..ff8b1a4df7 --- /dev/null +++ b/.claude/skills/orchestrate/scripts/run-loop.sh @@ -0,0 +1,215 @@ +#!/usr/bin/env bash +# run-loop.sh — Mechanical babysitter for the agent fleet (runs in its own tmux window) +# +# Handles ONLY two things that need no intelligence: +# idle → restart claude using --resume SESSION_ID (or --continue) to restore context +# approve → auto-approve safe dialogs, press Enter on numbered-option dialogs +# +# Everything else — ORCHESTRATOR:DONE, verification, /pr-test, final evaluation, +# marking done, deciding to close windows — is the orchestrating Claude's job. +# poll-cycle.sh sets state to pending_evaluation when ORCHESTRATOR:DONE is detected; +# the orchestrator's background poll loop handles it from there. +# +# Usage: run-loop.sh +# Env: POLL_INTERVAL (default: 30), ORCHESTRATOR_STATE_FILE + +set -euo pipefail + +# Copy scripts to a stable location outside the repo so they survive branch +# checkouts (e.g. recycle-agent.sh switching spare/N back into this worktree +# would wipe .claude/skills/orchestrate/scripts if the skill only exists on the +# current branch). +_ORIGIN_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +STABLE_SCRIPTS_DIR="$HOME/.claude/orchestrator/scripts" +mkdir -p "$STABLE_SCRIPTS_DIR" +cp "$_ORIGIN_DIR"/*.sh "$STABLE_SCRIPTS_DIR/" +chmod +x "$STABLE_SCRIPTS_DIR"/*.sh +SCRIPTS_DIR="$STABLE_SCRIPTS_DIR" + +STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}" +# Adaptive polling: starts at base interval, backs off up to POLL_IDLE_MAX when +# no agents need attention, resets on any activity or waiting_approval state. +POLL_INTERVAL="${POLL_INTERVAL:-30}" +POLL_IDLE_MAX=${POLL_IDLE_MAX:-300} +POLL_CURRENT=$POLL_INTERVAL + +# --------------------------------------------------------------------------- +# update_state WINDOW FIELD VALUE +# --------------------------------------------------------------------------- +update_state() { + local window="$1" field="$2" value="$3" + jq --arg w "$window" --arg f "$field" --arg v "$value" \ + '.agents |= map(if .window == $w then .[$f] = $v else . end)' \ + "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE" +} + +update_state_int() { + local window="$1" field="$2" value="$3" + jq --arg w "$window" --arg f "$field" --argjson v "$value" \ + '.agents |= map(if .window == $w then .[$f] = $v else . end)' \ + "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE" +} + +agent_field() { + jq -r --arg w "$1" --arg f "$2" \ + '.agents[] | select(.window == $w) | .[$f] // ""' \ + "$STATE_FILE" 2>/dev/null +} + +# --------------------------------------------------------------------------- +# wait_for_prompt WINDOW — wait up to 60s for Claude's ❯ prompt +# --------------------------------------------------------------------------- +wait_for_prompt() { + local window="$1" + for i in $(seq 1 60); do + local cmd pane + cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "") + pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "") + if echo "$pane" | grep -q "Enter to confirm"; then + tmux send-keys -t "$window" Down Enter; sleep 2; continue + fi + [[ "$cmd" == "node" ]] && echo "$pane" | grep -q "❯" && return 0 + sleep 1 + done + return 1 # timed out +} + +# --------------------------------------------------------------------------- +# wait_for_claude_idle WINDOW — wait up to 30s for Claude to reach idle ❯ prompt +# (no spinner or busy indicator visible in the last 3 lines of pane output) +# Returns 0 when idle, 1 on timeout. +# --------------------------------------------------------------------------- +wait_for_claude_idle() { + local window="$1" + local timeout="${2:-30}" + local elapsed=0 + while (( elapsed < timeout )); do + local cmd pane pane_tail + cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "") + pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "") + pane_tail=$(echo "$pane" | tail -3) + # Check full pane (not just tail) — 'Enter to confirm' dialog can scroll above last 3 lines. + # Do NOT reset elapsed — resetting allows an infinite loop if the dialog never clears. + if echo "$pane" | grep -q "Enter to confirm"; then + tmux send-keys -t "$window" Down Enter + sleep 2; (( elapsed += 2 )); continue + fi + # Must be running under node (Claude is live) + if [[ "$cmd" == "node" ]]; then + # Idle: ❯ prompt visible AND no spinner/busy text in last 3 lines + if echo "$pane_tail" | grep -q "❯" && \ + ! echo "$pane_tail" | grep -qE '[✳✽✢✶·✻✼✿❋✤]|Running…|Compacting'; then + return 0 + fi + fi + sleep 2 + (( elapsed += 2 )) + done + return 1 # timed out +} + +# --------------------------------------------------------------------------- +# handle_kick WINDOW STATE — only for idle (crashed) agents, not stuck +# --------------------------------------------------------------------------- +handle_kick() { + local window="$1" state="$2" + [[ "$state" != "idle" ]] && return # stuck agents handled by supervisor + + local worktree_path session_id + worktree_path=$(agent_field "$window" "worktree_path") + session_id=$(agent_field "$window" "session_id") + + echo "[$(date +%H:%M:%S)] KICK restart $window — agent exited, resuming session" + + # Wait for the shell prompt before typing — avoids sending into a still-draining pane + wait_for_claude_idle "$window" 30 \ + || echo "[$(date +%H:%M:%S)] KICK WARNING $window — pane still busy before resume, sending anyway" + + # Resume the exact session so the agent retains full context — no need to re-send objective + if [ -n "$session_id" ]; then + tmux send-keys -t "$window" "cd '${worktree_path}' && claude --resume '${session_id}' --permission-mode bypassPermissions" Enter + else + tmux send-keys -t "$window" "cd '${worktree_path}' && claude --continue --permission-mode bypassPermissions" Enter + fi + + wait_for_prompt "$window" || echo "[$(date +%H:%M:%S)] KICK WARNING $window — timed out waiting for ❯" +} + +# --------------------------------------------------------------------------- +# handle_approve WINDOW — auto-approve dialogs that need no judgment +# --------------------------------------------------------------------------- +handle_approve() { + local window="$1" + local pane_tail + pane_tail=$(tmux capture-pane -t "$window" -p 2>/dev/null | tail -3 || echo "") + + # Settings error dialog at startup + if echo "$pane_tail" | grep -q "Enter to confirm"; then + echo "[$(date +%H:%M:%S)] APPROVE dialog $window — settings error" + tmux send-keys -t "$window" Down Enter + return + fi + + # Numbered-option dialog (e.g. "Do you want to make this edit?") + # ❯ is already on option 1 (Yes) — Enter confirms it + if echo "$pane_tail" | grep -qE "❯\s*1\." || echo "$pane_tail" | grep -q "Esc to cancel"; then + echo "[$(date +%H:%M:%S)] APPROVE edit $window" + tmux send-keys -t "$window" "" Enter + return + fi + + # y/n prompt for safe operations + if echo "$pane_tail" | grep -qiE "(^git |^npm |^pnpm |^poetry |^pytest|^docker |^make |^cargo |^pip |^yarn |curl .*(localhost|127\.0\.0\.1))"; then + echo "[$(date +%H:%M:%S)] APPROVE safe $window" + tmux send-keys -t "$window" "y" Enter + return + fi + + # Anything else — supervisor handles it, just log + echo "[$(date +%H:%M:%S)] APPROVE skip $window — unknown dialog, supervisor will handle" +} + +# --------------------------------------------------------------------------- +# Main loop +# --------------------------------------------------------------------------- +echo "[$(date +%H:%M:%S)] run-loop started (mechanical only, poll ${POLL_INTERVAL}s→${POLL_IDLE_MAX}s adaptive)" +echo "[$(date +%H:%M:%S)] Supervisor: orchestrating Claude session (not a separate window)" +echo "---" + +while true; do + if ! jq -e '.active == true' "$STATE_FILE" >/dev/null 2>&1; then + echo "[$(date +%H:%M:%S)] active=false — exiting." + exit 0 + fi + + ACTIONS=$("$SCRIPTS_DIR/poll-cycle.sh" 2>/dev/null || echo "[]") + KICKED=0; DONE=0 + + while IFS= read -r action; do + [ -z "$action" ] && continue + WINDOW=$(echo "$action" | jq -r '.window // ""') + ACTION=$(echo "$action" | jq -r '.action // ""') + STATE=$(echo "$action" | jq -r '.state // ""') + + case "$ACTION" in + kick) handle_kick "$WINDOW" "$STATE" || true; KICKED=$(( KICKED + 1 )) ;; + approve) handle_approve "$WINDOW" || true ;; + complete) DONE=$(( DONE + 1 )) ;; # poll-cycle already set state=pending_evaluation; orchestrator handles + esac + done < <(echo "$ACTIONS" | jq -c '.[]' 2>/dev/null || true) + + RUNNING=$(jq '[.agents[] | select(.state | test("running|stuck|waiting_approval|idle"))] | length' \ + "$STATE_FILE" 2>/dev/null || echo 0) + + # Adaptive backoff: reset to base on activity or waiting_approval agents; back off when truly idle + WAITING=$(jq '[.agents[] | select(.state == "waiting_approval")] | length' "$STATE_FILE" 2>/dev/null || echo 0) + if (( KICKED > 0 || DONE > 0 || WAITING > 0 )); then + POLL_CURRENT=$POLL_INTERVAL + else + POLL_CURRENT=$(( POLL_CURRENT + POLL_CURRENT / 2 + 1 )) + (( POLL_CURRENT > POLL_IDLE_MAX )) && POLL_CURRENT=$POLL_IDLE_MAX + fi + + echo "[$(date +%H:%M:%S)] Poll — ${RUNNING} running ${KICKED} kicked ${DONE} recycled (next in ${POLL_CURRENT}s)" + sleep "$POLL_CURRENT" +done diff --git a/.claude/skills/orchestrate/scripts/spawn-agent.sh b/.claude/skills/orchestrate/scripts/spawn-agent.sh new file mode 100755 index 0000000000..7c565a525d --- /dev/null +++ b/.claude/skills/orchestrate/scripts/spawn-agent.sh @@ -0,0 +1,129 @@ +#!/usr/bin/env bash +# spawn-agent.sh — create tmux window, checkout branch, launch claude, send task +# +# Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...] +# SESSION — tmux session name, e.g. autogpt1 +# WORKTREE_PATH — absolute path to the git worktree +# SPARE_BRANCH — spare branch being replaced, e.g. spare/6 (saved for recycle) +# NEW_BRANCH — task branch to create, e.g. feat/my-feature +# OBJECTIVE — task description sent to the agent +# PR_NUMBER — (optional) GitHub PR number for completion verification +# STEPS... — (optional) required checkpoint names, e.g. pr-address pr-test +# +# Stdout: SESSION:WINDOW_INDEX (nothing else — callers rely on this) +# Exit non-zero on failure. + +set -euo pipefail + +if [ $# -lt 5 ]; then + echo "Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]" >&2 + exit 1 +fi + +SESSION="$1" +WORKTREE_PATH="$2" +SPARE_BRANCH="$3" +NEW_BRANCH="$4" +OBJECTIVE="$5" +PR_NUMBER="${6:-}" +STEPS=("${@:7}") +WORKTREE_NAME=$(basename "$WORKTREE_PATH") +STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}" + +# Generate a stable session ID so this agent's Claude session can always be resumed: +# claude --resume $SESSION_ID --permission-mode bypassPermissions +SESSION_ID=$(uuidgen 2>/dev/null || python3 -c "import uuid; print(uuid.uuid4())") + +# Create (or switch to) the task branch +git -C "$WORKTREE_PATH" checkout -b "$NEW_BRANCH" 2>/dev/null \ + || git -C "$WORKTREE_PATH" checkout "$NEW_BRANCH" + +# Open a new named tmux window; capture its numeric index +WIN_IDX=$(tmux new-window -t "$SESSION" -n "$WORKTREE_NAME" -P -F '#{window_index}') +WINDOW="${SESSION}:${WIN_IDX}" + +# Append the initial agent record to the state file so subsequent jq updates find it. +# This must happen before the pr_number/steps update below. +if [ -f "$STATE_FILE" ]; then + NOW=$(date +%s) + jq --arg window "$WINDOW" \ + --arg worktree "$WORKTREE_NAME" \ + --arg worktree_path "$WORKTREE_PATH" \ + --arg spare_branch "$SPARE_BRANCH" \ + --arg branch "$NEW_BRANCH" \ + --arg objective "$OBJECTIVE" \ + --arg session_id "$SESSION_ID" \ + --argjson now "$NOW" \ + '.agents += [{ + "window": $window, + "worktree": $worktree, + "worktree_path": $worktree_path, + "spare_branch": $spare_branch, + "branch": $branch, + "objective": $objective, + "session_id": $session_id, + "state": "running", + "checkpoints": [], + "last_output_hash": "", + "last_seen_at": $now, + "spawned_at": $now, + "idle_since": 0, + "revision_count": 0, + "last_rebriefed_at": 0 + }]' "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE" +fi + +# Store pr_number + steps in state file if provided (enables verify-complete.sh). +# The agent record was appended above so the jq select now finds it. +if [ -n "$PR_NUMBER" ] && [ -f "$STATE_FILE" ]; then + if [ "${#STEPS[@]}" -gt 0 ]; then + STEPS_JSON=$(printf '%s\n' "${STEPS[@]}" | jq -R . | jq -s .) + else + STEPS_JSON='[]' + fi + jq --arg w "$WINDOW" --arg pr "$PR_NUMBER" --argjson steps "$STEPS_JSON" \ + '.agents |= map(if .window == $w then . + {pr_number: $pr, steps: $steps, checkpoints: []} else . end)' \ + "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE" +fi + +# Launch claude with a stable session ID so it can always be resumed after a crash: +# claude --resume SESSION_ID --permission-mode bypassPermissions +tmux send-keys -t "$WINDOW" "cd '${WORKTREE_PATH}' && claude --permission-mode bypassPermissions --session-id '${SESSION_ID}'" Enter + +# wait_for_claude_idle — poll until the pane shows idle ❯ with no spinner in the last 3 lines. +# Returns 0 when idle, 1 on timeout. +_wait_idle() { + local window="$1" timeout="${2:-60}" elapsed=0 + while (( elapsed < timeout )); do + local cmd pane_tail + cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "") + pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "") + pane_tail=$(echo "$pane" | tail -3) + # Check full pane (not just tail) — 'Enter to confirm' dialog can appear above the last 3 lines + if echo "$pane" | grep -q "Enter to confirm"; then + tmux send-keys -t "$window" Down Enter + sleep 2; (( elapsed += 2 )); continue + fi + if [[ "$cmd" == "node" ]] && \ + echo "$pane_tail" | grep -q "❯" && \ + ! echo "$pane_tail" | grep -qE '[✳✽✢✶·✻✼✿❋✤]|Running…|Compacting'; then + return 0 + fi + sleep 2; (( elapsed += 2 )) + done + return 1 +} + +# Wait up to 60s for claude to be fully interactive and idle (❯ visible, no spinner). +if ! _wait_idle "$WINDOW" 60; then + echo "[spawn-agent] WARNING: timed out waiting for idle ❯ prompt on $WINDOW — sending objective anyway" >&2 +fi + +# Send the task. Split text and Enter — if combined, Enter can fire before the string +# is fully buffered, leaving the message stuck as "[Pasted text +N lines]" unsent. +tmux send-keys -t "$WINDOW" "${OBJECTIVE} Output each completed step as CHECKPOINT:. When ALL steps are done, output ORCHESTRATOR:DONE on its own line." +sleep 0.3 +tmux send-keys -t "$WINDOW" Enter + +# Only output the window address — nothing else (callers parse this) +echo "$WINDOW" diff --git a/.claude/skills/orchestrate/scripts/status.sh b/.claude/skills/orchestrate/scripts/status.sh new file mode 100755 index 0000000000..d1b191c05f --- /dev/null +++ b/.claude/skills/orchestrate/scripts/status.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# status.sh — print orchestrator status: state file summary + live tmux pane commands +# +# Usage: status.sh +# Reads: ~/.claude/orchestrator-state.json + +set -euo pipefail + +STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}" + +if [ ! -f "$STATE_FILE" ] || ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then + echo "No orchestrator state found at $STATE_FILE" + exit 0 +fi + +# Header: active status, session, thresholds, last poll +jq -r ' + "=== Orchestrator [\(if .active then "RUNNING" else "STOPPED" end)] ===", + "Session: \(.tmux_session // "unknown") | Idle threshold: \(.idle_threshold_seconds // 300)s", + "Last poll: \(if (.last_poll_at // 0) == 0 then "never" else (.last_poll_at | strftime("%H:%M:%S")) end)", + "" +' "$STATE_FILE" + +# Each agent: state, window, worktree/branch, truncated objective +AGENT_COUNT=$(jq '.agents | length' "$STATE_FILE") +if [ "$AGENT_COUNT" -eq 0 ]; then + echo " (no agents registered)" +else + jq -r ' + .agents[] | + " [\(.state | ascii_upcase)] \(.window) \(.worktree)/\(.branch)", + " \(.objective // "" | .[0:70])" + ' "$STATE_FILE" +fi + +echo "" + +# Live pane_current_command for non-done agents +while IFS= read -r WINDOW; do + [ -z "$WINDOW" ] && continue + CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "unreachable") + echo " $WINDOW live: $CMD" +done < <(jq -r '.agents[] | select(.state != "done") | .window' "$STATE_FILE" 2>/dev/null || true) diff --git a/.claude/skills/orchestrate/scripts/verify-complete.sh b/.claude/skills/orchestrate/scripts/verify-complete.sh new file mode 100644 index 0000000000..55ddfc18c6 --- /dev/null +++ b/.claude/skills/orchestrate/scripts/verify-complete.sh @@ -0,0 +1,180 @@ +#!/usr/bin/env bash +# verify-complete.sh — verify a PR task is truly done before marking the agent done +# +# Check order matters: +# 1. Checkpoints — did the agent do all required steps? +# 2. CI complete — no pending (bots post comments AFTER their check runs, must wait) +# 3. CI passing — no failures (agent must fix before done) +# 4. spawned_at — a new CI run was triggered after agent spawned (proves real work) +# 5. Unresolved threads — checked AFTER CI so bot-posted comments are included +# 6. CHANGES_REQUESTED — checked AFTER CI so bot reviews are included +# +# Usage: verify-complete.sh WINDOW +# Exit 0 = verified complete; exit 1 = not complete (stderr has reason) + +set -euo pipefail + +WINDOW="$1" +STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}" + +PR_NUMBER=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .pr_number // ""' "$STATE_FILE" 2>/dev/null) +STEPS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .steps // [] | .[]' "$STATE_FILE" 2>/dev/null || true) +CHECKPOINTS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .checkpoints // [] | .[]' "$STATE_FILE" 2>/dev/null || true) +WORKTREE_PATH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .worktree_path // ""' "$STATE_FILE" 2>/dev/null) +BRANCH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .branch // ""' "$STATE_FILE" 2>/dev/null) +SPAWNED_AT=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .spawned_at // "0"' "$STATE_FILE" 2>/dev/null || echo "0") + +# No PR number = cannot verify +if [ -z "$PR_NUMBER" ]; then + echo "NOT COMPLETE: no pr_number in state — set pr_number or mark done manually" >&2 + exit 1 +fi + +# --- Check 1: all required steps are checkpointed --- +MISSING="" +while IFS= read -r step; do + [ -z "$step" ] && continue + if ! echo "$CHECKPOINTS" | grep -qFx "$step"; then + MISSING="$MISSING $step" + fi +done <<< "$STEPS" + +if [ -n "$MISSING" ]; then + echo "NOT COMPLETE: missing checkpoints:$MISSING on PR #$PR_NUMBER" >&2 + exit 1 +fi + +# Resolve repo for all GitHub checks below +REPO=$(jq -r '.repo // ""' "$STATE_FILE" 2>/dev/null || echo "") +if [ -z "$REPO" ] && [ -n "$WORKTREE_PATH" ] && [ -d "$WORKTREE_PATH" ]; then + REPO=$(git -C "$WORKTREE_PATH" remote get-url origin 2>/dev/null \ + | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "") +fi + +if [ -z "$REPO" ]; then + echo "Warning: cannot resolve repo — skipping CI/thread checks" >&2 + echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓ (CI/thread checks skipped — no repo)" + exit 0 +fi + +CI_BUCKETS=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket 2>/dev/null || echo "[]") + +# --- Check 2: CI fully complete — no pending checks --- +# Pending checks MUST finish before we check threads/reviews: +# bots (Seer, Check PR Status, etc.) post comments and CHANGES_REQUESTED AFTER their CI check runs. +PENDING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "pending")] | length' 2>/dev/null || echo "0") +if [ "$PENDING" -gt 0 ]; then + PENDING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \ + | jq -r '[.[] | select(.bucket == "pending") | .name] | join(", ")' 2>/dev/null || echo "unknown") + echo "NOT COMPLETE: $PENDING CI checks still pending on PR #$PR_NUMBER ($PENDING_NAMES)" >&2 + exit 1 +fi + +# --- Check 3: CI passing — no failures --- +FAILING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "fail")] | length' 2>/dev/null || echo "0") +if [ "$FAILING" -gt 0 ]; then + FAILING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \ + | jq -r '[.[] | select(.bucket == "fail") | .name] | join(", ")' 2>/dev/null || echo "unknown") + echo "NOT COMPLETE: $FAILING failing CI checks on PR #$PR_NUMBER ($FAILING_NAMES)" >&2 + exit 1 +fi + +# --- Check 4: a new CI run was triggered AFTER the agent spawned --- +if [ -n "$BRANCH" ] && [ "${SPAWNED_AT:-0}" -gt 0 ]; then + LATEST_RUN_AT=$(gh run list --repo "$REPO" --branch "$BRANCH" \ + --json createdAt --limit 1 2>/dev/null | jq -r '.[0].createdAt // ""') + if [ -n "$LATEST_RUN_AT" ]; then + if date --version >/dev/null 2>&1; then + LATEST_RUN_EPOCH=$(date -d "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0") + else + LATEST_RUN_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0") + fi + if [ "$LATEST_RUN_EPOCH" -le "$SPAWNED_AT" ]; then + echo "NOT COMPLETE: latest CI run on $BRANCH predates agent spawn — agent may not have pushed yet" >&2 + exit 1 + fi + fi +fi + +OWNER=$(echo "$REPO" | cut -d/ -f1) +REPONAME=$(echo "$REPO" | cut -d/ -f2) + +# --- Check 5: no unresolved review threads (checked AFTER CI — bots post after their check) --- +UNRESOLVED=$(gh api graphql -f query=" + { repository(owner: \"${OWNER}\", name: \"${REPONAME}\") { + pullRequest(number: ${PR_NUMBER}) { + reviewThreads(first: 50) { nodes { isResolved } } + } + } + } +" --jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)] | length' 2>/dev/null || echo "0") + +if [ "$UNRESOLVED" -gt 0 ]; then + echo "NOT COMPLETE: $UNRESOLVED unresolved review threads on PR #$PR_NUMBER" >&2 + exit 1 +fi + +# --- Check 6: no CHANGES_REQUESTED (checked AFTER CI — bots post reviews after their check) --- +# A CHANGES_REQUESTED review is stale if the latest commit was pushed AFTER the review was submitted. +# Stale reviews (pre-dating the fixing commits) should not block verification. +# +# Fetch commits and latestReviews in a single call and fail closed — if gh fails, +# treat that as NOT COMPLETE rather than silently passing. +# Use latestReviews (not reviews) so each reviewer's latest state is used — superseded +# CHANGES_REQUESTED entries are automatically excluded when the reviewer later approved. +# Note: we intentionally use committedDate (not PR updatedAt) because updatedAt changes on any +# PR activity (bot comments, label changes) which would create false negatives. +PR_REVIEW_METADATA=$(gh pr view "$PR_NUMBER" --repo "$REPO" \ + --json commits,latestReviews 2>/dev/null) || { + echo "NOT COMPLETE: unable to fetch PR review metadata for PR #$PR_NUMBER" >&2 + exit 1 +} + +LATEST_COMMIT_DATE=$(jq -r '.commits[-1].committedDate // ""' <<< "$PR_REVIEW_METADATA") +CHANGES_REQUESTED_REVIEWS=$(jq '[.latestReviews[]? | select(.state == "CHANGES_REQUESTED")]' <<< "$PR_REVIEW_METADATA") + +BLOCKING_CHANGES_REQUESTED=0 +BLOCKING_REQUESTERS="" + +if [ -n "$LATEST_COMMIT_DATE" ] && [ "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length)" -gt 0 ]; then + if date --version >/dev/null 2>&1; then + LATEST_COMMIT_EPOCH=$(date -d "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0") + else + LATEST_COMMIT_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0") + fi + + while IFS= read -r review; do + [ -z "$review" ] && continue + REVIEW_DATE=$(echo "$review" | jq -r '.submittedAt // ""') + REVIEWER=$(echo "$review" | jq -r '.author.login // "unknown"') + if [ -z "$REVIEW_DATE" ]; then + # No submission date — treat as fresh (conservative: blocks verification) + BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 )) + BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}" + else + if date --version >/dev/null 2>&1; then + REVIEW_EPOCH=$(date -d "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0") + else + REVIEW_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0") + fi + if [ "$REVIEW_EPOCH" -gt "$LATEST_COMMIT_EPOCH" ]; then + # Review was submitted AFTER latest commit — still fresh, blocks verification + BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 )) + BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}" + fi + # Review submitted BEFORE latest commit — stale, skip + fi + done <<< "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -c '.[]')" +else + # No commit date or no changes_requested — check raw count as fallback + BLOCKING_CHANGES_REQUESTED=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length 2>/dev/null || echo "0") + BLOCKING_REQUESTERS=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -r '[.[].author.login] | join(", ")' 2>/dev/null || echo "unknown") +fi + +if [ "$BLOCKING_CHANGES_REQUESTED" -gt 0 ]; then + echo "NOT COMPLETE: CHANGES_REQUESTED (after latest commit) from ${BLOCKING_REQUESTERS} on PR #$PR_NUMBER" >&2 + exit 1 +fi + +echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓, CI complete + green, 0 unresolved threads, no CHANGES_REQUESTED" +exit 0 diff --git a/.claude/skills/pr-address/SKILL.md b/.claude/skills/pr-address/SKILL.md index a0c4690454..cb730f9ed1 100644 --- a/.claude/skills/pr-address/SKILL.md +++ b/.claude/skills/pr-address/SKILL.md @@ -25,43 +25,102 @@ Understand the **Why / What / How** before addressing comments — you need cont gh pr view {N} --json body --jq '.body' ``` +> If GraphQL is rate-limited, `gh pr view` fails. See [GitHub rate limits](#github-rate-limits) for REST fallbacks. + ## Fetch comments (all sources) ### 1. Inline review threads — GraphQL (primary source of actionable items) -Use GraphQL to fetch inline threads. It natively exposes `isResolved`, returns threads already grouped with all replies, and paginates via cursor — no manual thread reconstruction needed. +> ⚠️ **WARNING — PAGINATE ALL PAGES BEFORE ADDRESSING ANYTHING** +> +> `reviewThreads(first: 100)` returns at most 100 threads per page AND returns threads **oldest-first**. On a PR with many review cycles (e.g. 373 threads), the oldest 100–200 threads are from past cycles and are **all already resolved**. Filtering client-side with `select(.isResolved == false)` on page 1 therefore yields **0 results** — even though pages 2–4 contain many unresolved threads from recent review cycles. +> +> **This is the most common failure mode:** agent fetches page 1, sees 0 unresolved after filtering, stops pagination, reports "done" — while hundreds of unresolved threads sit on later pages. +> +> One observed PR had 142 total threads: page 1 returned 0 unresolved (all old/resolved), while pages 2–3 had 111 unresolved. Another with 373 threads across 4 pages also had page 1 entirely resolved. +> +> **The rule: ALWAYS paginate to `hasNextPage == false` regardless of the per-page unresolved count. Never stop early because a page returns 0 unresolved.** + +**Step 1 — Fetch total count and sanity-check the newest threads:** ```bash +# Get total count and the newest 100 threads (last: 100 returns newest-first) gh api graphql -f query=' { repository(owner: "Significant-Gravitas", name: "AutoGPT") { pullRequest(number: {N}) { - reviewThreads(first: 100) { - pageInfo { hasNextPage endCursor } - nodes { - id - isResolved - path - comments(last: 1) { - nodes { databaseId body author { login } createdAt } + reviewThreads { totalCount } + newest: reviewThreads(last: 100) { + nodes { isResolved } + } + } + } +}' | jq '{ total: .data.repository.pullRequest.reviewThreads.totalCount, newest_unresolved: [.data.repository.pullRequest.newest.nodes[] | select(.isResolved == false)] | length }' +``` + +If `total > 100`, you have multiple pages — you **must** paginate all of them regardless of what `newest_unresolved` shows. The `last: 100` check is a sanity signal only; the full loop below is mandatory. + +**Step 2 — Collect all unresolved thread IDs across all pages:** + +```bash +# Accumulate all unresolved threads — loop until hasNextPage == false +CURSOR="" +ALL_THREADS="[]" +while true; do + AFTER=${CURSOR:+", after: \"$CURSOR\""} + PAGE=$(gh api graphql -f query=" + { + repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") { + pullRequest(number: {N}) { + reviewThreads(first: 100${AFTER}) { + pageInfo { hasNextPage endCursor } + nodes { + id + isResolved + path + line + comments(last: 1) { + nodes { databaseId body author { login } } + } } } } } - } -}' + }") + # Append unresolved nodes from this page + PAGE_THREADS=$(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)]') + ALL_THREADS=$(echo "$ALL_THREADS $PAGE_THREADS" | jq -s 'add') + HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage') + CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor') + [ "$HAS_NEXT" = "false" ] && break +done + +# Reverse so newest threads (last pages) are addressed first — GitHub returns oldest-first +# and the most recent review cycle's comments are the ones blocking approval. +ALL_THREADS=$(echo "$ALL_THREADS" | jq 'reverse') + +echo "Total unresolved threads: $(echo "$ALL_THREADS" | jq 'length')" +echo "$ALL_THREADS" | jq '[.[] | {id, path, line, body: .comments.nodes[0].body[:200]}]' ``` -If `pageInfo.hasNextPage` is true, fetch subsequent pages by adding `after: ""` to `reviewThreads(first: 100, after: "...")` and repeat until `hasNextPage` is false. +**Step 3 — Address every thread in `ALL_THREADS`, then resolve.** + +Only after this loop completes (all pages fetched, count confirmed) should you begin making fixes. + +> **Why reverse?** GraphQL returns threads oldest-first and exposes no `orderBy` option. A PR with 373 threads has ~4 pages; threads from the latest review cycle land on the last pages. Processing in reverse ensures the newest, most blocking comments are addressed first — the earlier pages mostly contain outdated threads from prior cycles. **Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls. +> If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the REST fallback (flat comment list — no thread grouping or `isResolved`). + ### 2. Top-level reviews — REST (MUST paginate) ```bash gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate ``` +> **Already REST — unaffected by GraphQL rate limits or outages. Continue polling reviews normally even when GraphQL is exhausted.** + **CRITICAL — always `--paginate`.** Reviews default to 30 per page. PRs can have 80–170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30 — including `autogpt-reviewer`'s structured review which is typically posted after several CI runs and sits well beyond the first page. Two things to extract: @@ -80,20 +139,71 @@ Two things to extract: gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate ``` +> **Already REST — unaffected by GraphQL rate limits.** + Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response. ## For each unaddressed comment -Address comments **one at a time**: fix → commit → push → inline reply → next. +**CRITICAL: The only valid sequence is fix → commit → push → reply → resolve. Never resolve a thread without a real code commit.** + +Resolving a thread via `resolveReviewThread` without an actual fix is the most common failure mode — it makes unresolved counts drop without any real change, producing a false "done" signal. If the issue was genuinely a false positive (no code change needed), reply explaining why and then resolve. Otherwise: + +Address comments **one at a time**: fix → commit → push → inline reply → resolve. 1. Read the referenced code, make the fix (or reply explaining why it's not needed) 2. Commit and push the fix 3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry): +Use a **markdown commit link** so GitHub renders it as a clickable reference. Always get the full SHA with `git rev-parse HEAD` **after** committing — never copy a SHA from a previous commit or hardcode one: + +```bash +FULL_SHA=$(git rev-parse HEAD) +gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies \ + -f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): " +``` + | Comment type | How to reply | |---|---| -| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in : "` | -| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in : "` | +| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): "` | +| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): "` | + +### What counts as a valid resolution + +Only two situations justify calling `resolveReviewThread`: + +1. **Real code fix**: you changed the code, committed + pushed, and replied with the SHA. The commit diff must actually address the concern — not just touch the same file. +2. **Genuine false positive**: the reviewer's concern does not apply to this code, and you can give a specific technical reason (e.g. "Not applicable — `sdk_cwd` is pre-validated by `_make_sdk_cwd()` which applies normpath + prefix assertion before reaching this point"). + +**Anti-patterns that look resolved but aren't — never do these:** +- `"Accepted, tracked as follow-up"` — a deferral, not a fix. The concern is still open. Do not resolve. +- `"Acknowledged"` or `"Same as above"` — these are acknowledgements, not fixes. Do not resolve. +- `"Fixed in abc1234"` where `abc1234` is a commit that doesn't actually change the flagged line/logic — dishonest. Verify `git show abc1234 -- path/to/file` changes the right thing before posting. +- Resolving without replying — the reviewer never sees what happened. + +When in doubt: if a code change is needed, make it. A deferred issue means the thread stays open until the follow-up PR is merged. + +## Codecov coverage + +Codecov patch target is **80%** on changed lines. Checks are **informational** (not blocking) but should be green. + +### Running coverage locally + +**Backend** (from `autogpt_platform/backend/`): +```bash +poetry run pytest -s -vv --cov=backend --cov-branch --cov-report term-missing +``` + +**Frontend** (from `autogpt_platform/frontend/`): +```bash +pnpm vitest run --coverage +``` + +### When codecov/patch fails + +1. Find uncovered files: `git diff --name-only $(gh pr view --json baseRefName --jq '.baseRefName')...HEAD` +2. For each uncovered file — extract inline logic to `helpers.ts`/`helpers.py` and test those (highest ROI). Colocate tests as `*_test.py` (backend) or `__tests__/*.test.ts` (frontend). +3. Run coverage locally to verify, commit, push. ## Format and commit @@ -119,6 +229,22 @@ Then commit and **push immediately** — never batch commits without pushing. Ea For backend commits in worktrees: `poetry run git commit` (pre-commit hooks). +## Coverage + +Codecov enforces patch coverage on new/changed lines — new code you write must be tested. Before pushing, verify you haven't left new lines uncovered: + +```bash +cd autogpt_platform/backend +poetry run pytest --cov=. --cov-report=term-missing {path/to/changed/module} +``` + +Look for lines marked `miss` — those are uncovered. Add tests for any new code you wrote as part of addressing comments. + +**Rules:** +- New code you add should have tests +- Don't remove existing tests when fixing comments +- If a reviewer asks you to delete code, also delete its tests, but verify coverage hasn't dropped on remaining lines + ## The loop ```text @@ -208,3 +334,162 @@ git push ``` 5. Restart the polling loop from the top — new commits reset CI status. + +## GitHub rate limits + +Three distinct rate limits exist — they have different causes, error shapes, and recovery times: + +| Error | HTTP code | Cause | Recovery | +|---|---|---|---| +| `{"code":"abuse"}` | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **2–3 minutes**. 60s is often not enough. | +| `{"message":"API rate limit exceeded"}` | 429 | Primary REST rate limit — 5000 calls/hr per user | Wait until `X-RateLimit-Reset` header timestamp | +| `GraphQL: API rate limit already exceeded for user ID ...` | 403 on stderr, `gh` exits 1 | **GraphQL-specific** per-user limit — distinct from REST's 5000/hr and from the abuse secondary limit. Trips faster than REST because point costs per query. | Wait until the GraphQL window resets (typically ~1 hour from the first call in the window). REST still works — use fallbacks below. | + +**Prevention:** Add `sleep 3` between individual thread reply API calls. When posting >20 replies, increase to `sleep 5`. + +### Detection + +The `gh` CLI surfaces the GraphQL limit on stderr with the exact string `GraphQL: API rate limit already exceeded for user ID ` and exits 1 — any `gh api graphql ...` **or** `gh pr view ...` call fails. Check current quota and reset time via the REST endpoint that reports GraphQL quota (this call is REST and still works whether GraphQL is rate-limited OR fully down): + +```bash +gh api rate_limit --jq '.resources.graphql' # { "limit": 5000, "used": 5000, "remaining": 0, "reset": 1729...} +# Human-readable reset: +gh api rate_limit --jq '.resources.graphql.reset' | xargs -I{} date -r {} +``` + +Retry when `remaining > 0`. If you need to proceed sooner, sleep 2–5 min and probe again — the limit is per user, not per machine, so other concurrent agents under the same token also consume it. + +### What keeps working + +When GraphQL is unavailable (rate-limited or outage): + +- **Keeps working (REST):** top-level reviews fetch, conversation comments fetch, all inline-comment replies, CI status (`gh pr checks`), and the `gh api rate_limit` probe. +- **Degraded:** inline thread list — fall back to flat `/pulls/{N}/comments` REST, which drops thread grouping, `isResolved`, and Relay thread IDs. You still get comment bodies and the `databaseId` as `id`, enough to read and reply. +- **Blocked:** `gh pr view`, the `resolveReviewThread` mutation, and any new `gh api graphql` queries — wait for the quota to reset. + +### Fall back to REST + +**PR metadata reads** — `gh pr view` uses GraphQL under the hood; use the REST pulls endpoint instead, which returns the full PR object: + +```bash +gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.body' # == --json body +gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.base.ref' # == --json baseRefName +gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.mergeable' # == --json mergeable +``` + +Note: REST `mergeable` returns `true|false|null`; GraphQL returns `MERGEABLE|CONFLICTING|UNKNOWN`. The `null` case maps to `UNKNOWN` — treat it the same (still computing; poll again). + +**Inline comments (flat list)** — no thread grouping or `isResolved`, but enough to read and reply: + +```bash +gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate \ + | jq '[.[] | {id, path, line, user: .user.login, body: .body[:200], in_reply_to_id}]' +``` + +Use this degraded mode to make progress on the fix → reply loop, then return to GraphQL for `resolveReviewThread` once the rate limit resets. + +**Replies** — already REST-native (`/pulls/{N}/comments/{ID}/replies`); no change needed, use the same command as the main flow. + +**`resolveReviewThread`** — **no REST equivalent**; GitHub does not expose a REST endpoint for thread resolution. Queue the thread IDs needing resolution, wait for the GraphQL limit to reset, then run the resolve mutations in a batch (with `sleep 3` between calls, per the secondary-limit guidance). + +### Recovery from secondary rate limit (403 abuse) + +1. Stop all API writes immediately +2. Wait **2 minutes minimum** (not 60s — secondary limits are stricter) +3. Resume with `sleep 3` between each call +4. If 403 persists after 2 min, wait another 2 min before retrying + +Never batch all replies in a tight loop — always space them out. + +## Parallel thread resolution + +When a PR has more than 10 unresolved threads, addressing one commit per thread is slow. Use this strategy instead: + +### Group by file, batch per commit + +1. Sort `ALL_THREADS` by `path` — threads in the same file can share a single commit. +2. Fix all threads in one file → `git commit` → `git push` → reply to **all** those threads with the same SHA → resolve them all. +3. Move to the next file group and repeat. + +This reduces N commits to (number of files touched), which is usually 3–5 instead of 15–30. + +### Posting replies concurrently (for large batches) + +For truly independent thread groups (different files, no shared logic), you can post replies in parallel using background subshells — but always space out API writes: + +```bash +# Post replies to a batch of threads concurrently, 3s apart +( + sleep 3 + gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID1}/replies \ + -f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): ..." +) & +( + sleep 6 + gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID2}/replies \ + -f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): ..." +) & +wait # wait for all background replies before resolving +``` + +Then resolve sequentially (GraphQL mutations): +```bash +for THREAD_ID in "$THREAD1" "$THREAD2" "$THREAD3"; do + gh api graphql -f query="mutation { resolveReviewThread(input: {threadId: \"${THREAD_ID}\"}) { thread { isResolved } } }" + sleep 3 +done +``` + +**Always sleep 3s between individual API writes** — GitHub's secondary rate limit (403) triggers on bursts of >20 writes. Increase to `sleep 5` when posting more than 20 replies in a batch. + +## Resolving threads via GraphQL + +Use `resolveReviewThread` **only after** the commit is pushed and the reply is posted: + +```bash +gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "THREAD_ID"}) { thread { isResolved } } }' +``` + +**Never call this mutation before committing the fix.** The orchestrator will verify actual unresolved counts via GraphQL after you output `ORCHESTRATOR:DONE` — false resolutions will be caught and you will be re-briefed. + +> `resolveReviewThread` is GraphQL-only — no REST equivalent. If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the queue-and-retry flow. + +### Verify actual count before outputting ORCHESTRATOR:DONE + +Before claiming "0 unresolved threads", always query GitHub directly — don't rely on your own bookkeeping. Paginate all pages — a single `first: 100` query misses threads beyond page 1: + +```bash +# Step 1: get total thread count +gh api graphql -f query=' +{ + repository(owner: "Significant-Gravitas", name: "AutoGPT") { + pullRequest(number: {N}) { + reviewThreads { totalCount } + } + } +}' | jq '.data.repository.pullRequest.reviewThreads.totalCount' + +# Step 2: paginate all pages, count truly unresolved +CURSOR=""; UNRESOLVED=0 +while true; do + AFTER=${CURSOR:+", after: \"$CURSOR\""} + PAGE=$(gh api graphql -f query=" + { + repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") { + pullRequest(number: {N}) { + reviewThreads(first: 100${AFTER}) { + pageInfo { hasNextPage endCursor } + nodes { isResolved } + } + } + } + }") + UNRESOLVED=$(( UNRESOLVED + $(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved==false)] | length') )) + HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage') + CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor') + [ "$HAS_NEXT" = "false" ] && break +done +echo "Unresolved threads: $UNRESOLVED" +``` + +Only output `ORCHESTRATOR:DONE` after this loop reports 0. diff --git a/.claude/skills/pr-test/SKILL.md b/.claude/skills/pr-test/SKILL.md index b915cc55ab..b368fb7f0d 100644 --- a/.claude/skills/pr-test/SKILL.md +++ b/.claude/skills/pr-test/SKILL.md @@ -5,7 +5,7 @@ user-invocable: true argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)" metadata: author: autogpt-team - version: "2.0.0" + version: "2.1.0" --- # Manual E2E Test @@ -180,6 +180,94 @@ Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`: **Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify. +## Step 3.0: Claim the testing lock (coordinate parallel agents) + +Multiple worktrees share the same host — Docker infra (postgres, redis, clamav), app ports (3000/8006/…), and the test user. Two agents running `/pr-test` concurrently will corrupt each other's state (connection-pool exhaustion, port binds failing silently, cross-test assertions). Use the root-worktree lock file to take turns. + +### Lock file contract + +Path (**always** the root worktree so all siblings see it): `/Users/majdyz/Code/AutoGPT/.ign.testing.lock` + +Body (one `key=value` per line): +``` +holder= +pid= +started= +heartbeat= +worktree= +branch= +intent= +``` + +### Claim + +```bash +LOCK=/Users/majdyz/Code/AutoGPT/.ign.testing.lock +NOW=$(date -u +%Y-%m-%dT%H:%MZ) +STALE_AFTER_MIN=5 + +if [ -f "$LOCK" ]; then + HB=$(grep '^heartbeat=' "$LOCK" | cut -d= -f2) + HB_EPOCH=$(date -j -f '%Y-%m-%dT%H:%MZ' "$HB" +%s 2>/dev/null || date -d "$HB" +%s 2>/dev/null || echo 0) + AGE_MIN=$(( ( $(date -u +%s) - HB_EPOCH ) / 60 )) + if [ "$AGE_MIN" -gt "$STALE_AFTER_MIN" ]; then + echo "WARN: stale lock (${AGE_MIN}m old) — reclaiming" + cat "$LOCK" | sed 's/^/ stale: /' + else + echo "Another agent holds the lock:"; cat "$LOCK" + echo "Wait until released or resume after $((STALE_AFTER_MIN - AGE_MIN))m." + exit 1 + fi +fi + +cat > "$LOCK" < /tmp/pr-test-heartbeat.pid +``` + +### Release (always — even on failure) + +```bash +kill "$HEARTBEAT_PID" 2>/dev/null +rm -f "$LOCK" /tmp/pr-test-heartbeat.pid +echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock" \ + >> /Users/majdyz/Code/AutoGPT/.ign.testing.log +``` + +Use a `trap` so release runs even on `exit 1`: +```bash +trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM +``` + +### Shared status log + +`/Users/majdyz/Code/AutoGPT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes: +```bash +echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] " \ + >> /Users/majdyz/Code/AutoGPT/.ign.testing.log +``` + ## Step 3: Environment setup ### 3a. Copy .env files from the root worktree @@ -248,7 +336,87 @@ docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocke done ``` -### 3e. Build and start +**Native mode also:** when running the app natively (see 3e-native), kill any stray host processes and free the app ports before starting — otherwise `poetry run app` and `pnpm dev` will fail to bind. + +```bash +# Kill stray native app processes from prior runs +pkill -9 -f "python.*backend" 2>/dev/null || true +pkill -9 -f "poetry run app" 2>/dev/null || true +pkill -9 -f "next-server|next dev" 2>/dev/null || true + +# Free app ports (errors per port are ignored — port may simply be unused) +for port in 3000 8006 8001 8002 8005 8008; do + lsof -ti :$port -sTCP:LISTEN | xargs -r kill -9 2>/dev/null || true +done +``` + +### 3e-native. Run the app natively (PREFERRED for iterative dev) + +Native mode runs infra (postgres, supabase, redis, rabbitmq, clamav) in docker but runs the backend and frontend directly on the host. This avoids the 3-8 minute `docker compose build` cycle on every backend change — code edits are picked up on process restart (seconds) instead of a full image rebuild. + +**When to prefer native mode (default for this skill):** +- Iterative dev/debug loops where you're editing backend or frontend code between test runs +- Any PR that touches Python/TS source but not Dockerfiles, compose config, or infra images +- Fast repro of a failing scenario — restart `poetry run app` in a couple of seconds + +**When to prefer docker mode (3e fallback):** +- Testing changes to `Dockerfile`, `docker-compose.yml`, or base images +- Production-parity smoke tests (exact container env, networking, volumes) +- CI-equivalent runs where you need the exact image that'll ship + +**Note on 3b (copilot auth):** no npm install anywhere. `poetry install` pulls in `claude_agent_sdk`, which ships its own Claude CLI binary — available on `PATH` whenever you run commands via `poetry run` (native) OR whenever the copilot_executor container is built from its Poetry lockfile (docker). The OAuth token extraction still applies (same `refresh_claude_token.sh` call). + +**Preamble:** before starting native, run the kill-stray + free-ports block from 3c's "Native mode also" subsection. + +**1. Start infra only (one-time per session):** + +```bash +cd $PLATFORM_DIR && docker compose --profile local up deps --detach --remove-orphans --build +``` + +This brings up postgres/supabase/redis/rabbitmq/clamav and skips all app services. + +**2. Start the backend natively:** + +```bash +cd $BACKEND_DIR && (poetry run app 2>&1 | tee .ign.application.logs) & +``` + +`poetry run app` spawns **all** app subprocesses — `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, `database_manager` — inside ONE parent process. No separate containers, no separate terminals. The `.ign.application.logs` prefix is already gitignored. + +**3. Wait for the backend on :8006 BEFORE starting the frontend.** This ordering matters — the frontend's `pnpm dev` startup invokes `generate-api-queries`, which fetches `/openapi.json` from the backend. If the backend isn't listening yet, `pnpm dev` fails immediately. + +```bash +for i in $(seq 1 60); do + if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:8006/docs 2>/dev/null)" = "200" ]; then + echo "Backend ready" + break + fi + sleep 2 +done +``` + +**4. Start the frontend natively:** + +```bash +cd $FRONTEND_DIR && (pnpm dev 2>&1 | tee .ign.frontend.logs) & +``` + +**5. Wait for the frontend on :3000:** + +```bash +for i in $(seq 1 60); do + if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:3000 2>/dev/null)" = "200" ]; then + echo "Frontend ready" + break + fi + sleep 2 +done +``` + +Once both are up, skip 3e/3f and go straight to **3g/3h** (feature flags / test user creation). + +### 3e. Build and start (docker — fallback) ```bash cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20 @@ -310,6 +478,28 @@ TOKEN=$(curl -s -X POST 'http://localhost:8000/auth/v1/token?grant_type=password curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/... ``` +### 3i. Disable onboarding for test user + +The frontend redirects to `/onboarding` when the `VISIT_COPILOT` step is not in `completedSteps`. +Mark it complete via the backend API so every browser test lands on the real feature UI: + +```bash +ONBOARDING_RESULT=$(curl -s --max-time 30 -X POST \ + "http://localhost:8006/api/onboarding/step?step=VISIT_COPILOT" \ + -H "Authorization: Bearer $TOKEN") +echo "Onboarding bypass: $ONBOARDING_RESULT" + +# Verify it took effect +ONBOARDING_STATUS=$(curl -s --max-time 30 \ + "http://localhost:8006/api/onboarding/completed" \ + -H "Authorization: Bearer $TOKEN" | jq -r '.is_completed') +echo "Onboarding completed: $ONBOARDING_STATUS" +if [ "$ONBOARDING_STATUS" != "true" ]; then + echo "ERROR: onboarding bypass failed — browser tests will hit /onboarding instead of the target feature. Investigate before proceeding." + exit 1 +fi +``` + ## Step 4: Run tests ### Service ports reference @@ -420,6 +610,22 @@ agent-browser --session-name pr-test snapshot | grep "text:" ### Checking logs +**Native mode:** when running via `poetry run app` + `pnpm dev`, all app logs stream to the `.ign.*.logs` files written by the `tee` pipes in 3e-native. `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, and `database_manager` are all subprocesses of the single `poetry run app` parent, so their output is interleaved in `.ign.application.logs`. + +```bash +# Backend (all app subprocesses interleaved) +tail -f $BACKEND_DIR/.ign.application.logs + +# Frontend (Next.js dev server) +tail -f $FRONTEND_DIR/.ign.frontend.logs + +# Filter for errors across either log +grep -iE "error|exception|traceback" $BACKEND_DIR/.ign.application.logs | tail -20 +grep -iE "error|exception|traceback" $FRONTEND_DIR/.ign.frontend.logs | tail -20 +``` + +**Docker mode:** + ```bash # Backend REST server docker logs autogpt_platform-rest_server-1 2>&1 | tail -30 @@ -547,6 +753,8 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations **This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.** +**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes. + ```bash # Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely) REPO="Significant-Gravitas/AutoGPT" @@ -584,15 +792,27 @@ TREE_JSON+=']' # Step 2: Create tree, commit, and branch ref TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha') -COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \ - -f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \ - -f tree="$TREE_SHA" \ - --jq '.sha') + +# Resolve parent commit so screenshots are chained, not orphan root commits +PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || echo "") +if [ -n "$PARENT_SHA" ]; then + COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \ + -f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \ + -f tree="$TREE_SHA" \ + -f "parents[]=$PARENT_SHA" \ + --jq '.sha') +else + COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \ + -f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \ + -f tree="$TREE_SHA" \ + --jq '.sha') +fi + gh api "repos/${REPO}/git/refs" \ -f ref="refs/heads/${SCREENSHOTS_BRANCH}" \ -f sha="$COMMIT_SHA" 2>/dev/null \ || gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \ - -X PATCH -f sha="$COMMIT_SHA" -f force=true + -X PATCH -f sha="$COMMIT_SHA" -F force=true ``` Then post the comment with **inline images AND explanations for each screenshot**: @@ -658,6 +878,15 @@ INNEREOF gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE" rm -f "$COMMENT_FILE" + +# Verify the posted comment contains inline images — exit 1 if none found +# Use separate --paginate + jq pipe: --jq applies per-page, not to the full list +LAST_COMMENT=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" --paginate 2>/dev/null | jq -r '.[-1].body // ""') +if ! echo "$LAST_COMMENT" | grep -q '!\['; then + echo "ERROR: Posted comment contains no inline images (![). Bare directory links are not acceptable." >&2 + exit 1 +fi +echo "✓ Inline images verified in posted comment" ``` **The PR comment MUST include:** @@ -667,6 +896,103 @@ rm -f "$COMMENT_FILE" This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch. +## Step 8: Evaluate and post a formal PR review + +After the test comment is posted, evaluate whether the run was thorough enough to make a merge decision, then post a formal GitHub review (approve or request changes). **This step is mandatory — every test run MUST end with a formal review decision.** + +### Evaluation criteria + +Re-read the PR description: +```bash +gh pr view "$PR_NUMBER" --json body --jq '.body' --repo "$REPO" +``` + +Score the run against each criterion: + +| Criterion | Pass condition | +|-----------|---------------| +| **Coverage** | Every feature/change described in the PR has at least one test scenario | +| **All scenarios pass** | No FAIL rows in the results table | +| **Negative tests** | At least one failure-path test per feature (invalid input, unauthorized, edge case) | +| **Before/after evidence** | Every state-changing API call has before/after values logged | +| **Screenshots are meaningful** | Screenshots show the actual state change, not just a loading spinner or blank page | +| **No regressions** | Existing core flows (login, agent create/run) still work | + +### Decision logic + +``` +ALL criteria pass → APPROVE +Any scenario FAIL or missing PR feature → REQUEST_CHANGES (list gaps) +Evidence weak (no before/after, vague shots) → REQUEST_CHANGES (list what's missing) +``` + +### Post the review + +```bash +REVIEW_FILE=$(mktemp) + +# Count results +PASS_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "PASS" || true) +FAIL_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "FAIL" || true) +TOTAL=$(( PASS_COUNT + FAIL_COUNT )) + +# List any coverage gaps found during evaluation (populate this array as you assess) +# e.g. COVERAGE_GAPS=("PR claims to add X but no test covers it") +COVERAGE_GAPS=() +``` + +**If APPROVING** — all criteria met, zero failures, full coverage: + +```bash +cat > "$REVIEW_FILE" < "$REVIEW_FILE" <&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`. diff --git a/.claude/skills/write-frontend-tests/SKILL.md b/.claude/skills/write-frontend-tests/SKILL.md new file mode 100644 index 0000000000..389de2023b --- /dev/null +++ b/.claude/skills/write-frontend-tests/SKILL.md @@ -0,0 +1,225 @@ +--- +name: write-frontend-tests +description: "Analyze the current branch diff against dev, plan integration tests for changed frontend pages/components, and write them. TRIGGER when user asks to write frontend tests, add test coverage, or 'write tests for my changes'." +user-invocable: true +args: "[base branch] — defaults to dev. Optionally pass a specific base branch to diff against." +metadata: + author: autogpt-team + version: "1.0.0" +--- + +# Write Frontend Tests + +Analyze the current branch's frontend changes, plan integration tests, and write them. + +## References + +Before writing any tests, read the testing rules and conventions: + +- `autogpt_platform/frontend/TESTING.md` — testing strategy, file locations, examples +- `autogpt_platform/frontend/src/tests/AGENTS.md` — detailed testing rules, MSW patterns, decision flowchart +- `autogpt_platform/frontend/src/tests/integrations/test-utils.tsx` — custom render with providers +- `autogpt_platform/frontend/src/tests/integrations/vitest.setup.tsx` — MSW server setup + +## Step 1: Identify changed frontend files + +```bash +BASE_BRANCH="${ARGUMENTS:-dev}" +cd autogpt_platform/frontend + +# Get changed frontend files (excluding generated, config, and test files) +git diff "$BASE_BRANCH"...HEAD --name-only -- src/ \ + | grep -v '__generated__' \ + | grep -v '__tests__' \ + | grep -v '\.test\.' \ + | grep -v '\.stories\.' \ + | grep -v '\.spec\.' +``` + +Also read the diff to understand what changed: + +```bash +git diff "$BASE_BRANCH"...HEAD --stat -- src/ +git diff "$BASE_BRANCH"...HEAD -- src/ | head -500 +``` + +## Step 2: Categorize changes and find test targets + +For each changed file, determine: + +1. **Is it a page?** (`page.tsx`) — these are the primary test targets +2. **Is it a hook?** (`use*.ts`) — test via the page/component that uses it; avoid direct `renderHook()` tests unless it is a shared reusable hook with standalone business logic +3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation +4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic + +**Priority order:** + +1. Pages with new/changed data fetching or user interactions +2. Components with complex internal logic (modals, forms, wizards) +3. Shared hooks with standalone business logic when UI-level coverage is impractical +4. Pure helper functions + +Skip: styling-only changes, type-only changes, config changes. + +## Step 3: Check for existing tests + +For each test target, check if tests already exist: + +```bash +# For a page at src/app/(platform)/library/page.tsx +ls src/app/\(platform\)/library/__tests__/ 2>/dev/null + +# For a component at src/app/(platform)/library/components/AgentCard/AgentCard.tsx +ls src/app/\(platform\)/library/components/AgentCard/__tests__/ 2>/dev/null +``` + +Note which targets have no tests (need new files) vs which have tests that need updating. + +## Step 4: Identify API endpoints used + +For each test target, find which API hooks are used: + +```bash +# Find generated API hook imports in the changed files +grep -rn 'from.*__generated__/endpoints' src/app/\(platform\)/library/ +grep -rn 'use[A-Z].*V[12]' src/app/\(platform\)/library/ +``` + +For each API hook found, locate the corresponding MSW handler: + +```bash +# If the page uses useGetV2ListLibraryAgents, find its MSW handlers +grep -rn 'getGetV2ListLibraryAgents.*Handler' src/app/api/__generated__/endpoints/library/library.msw.ts +``` + +List every MSW handler you will need (200 for happy path, 4xx for error paths). + +## Step 5: Write the test plan + +Before writing code, output a plan as a numbered list: + +``` +Test plan for [branch name]: + +1. src/app/(platform)/library/__tests__/main.test.tsx (NEW) + - Renders page with agent list (MSW 200) + - Shows loading state + - Shows error state (MSW 422) + - Handles empty agent list + +2. src/app/(platform)/library/__tests__/search.test.tsx (NEW) + - Filters agents by search query + - Shows no results message + - Clears search + +3. src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx (UPDATE) + - Add test for new "duplicate" action +``` + +Present this plan to the user. Wait for confirmation before proceeding. If the user has feedback, adjust the plan. + +## Step 6: Write the tests + +For each test file in the plan, follow these conventions: + +### File structure + +```tsx +import { render, screen, waitFor } from "@/tests/integrations/test-utils"; +import { server } from "@/mocks/mock-server"; +// Import MSW handlers for endpoints the page uses +import { + getGetV2ListLibraryAgentsMockHandler200, + getGetV2ListLibraryAgentsMockHandler422, +} from "@/app/api/__generated__/endpoints/library/library.msw"; +// Import the component under test +import LibraryPage from "../page"; + +describe("LibraryPage", () => { + test("renders agent list from API", async () => { + server.use(getGetV2ListLibraryAgentsMockHandler200()); + + render(); + + expect(await screen.findByText(/my agents/i)).toBeDefined(); + }); + + test("shows error state on API failure", async () => { + server.use(getGetV2ListLibraryAgentsMockHandler422()); + + render(); + + expect(await screen.findByText(/error/i)).toBeDefined(); + }); +}); +``` + +### Rules + +- Use `render()` from `@/tests/integrations/test-utils` (NOT from `@testing-library/react` directly) +- Use `server.use()` to set up MSW handlers BEFORE rendering +- Use `findBy*` (async) for elements that appear after data fetching — NOT `getBy*` +- Use `getBy*` only for elements that are immediately present in the DOM +- Use `screen` queries — do NOT destructure from `render()` +- Use `waitFor` when asserting side effects or state changes after interactions +- Import `fireEvent` or `userEvent` from the test-utils for interactions +- Do NOT mock internal hooks or functions — mock at the API boundary via MSW +- Prefer Orval-generated MSW handlers and response builders over hand-built API response objects +- Do NOT use `act()` manually — `render` and `fireEvent` handle it +- Keep tests focused: one behavior per test +- Use descriptive test names that read like sentences + +### Test location + +``` +# For pages: __tests__/ next to page.tsx +src/app/(platform)/library/__tests__/main.test.tsx + +# For complex standalone components: __tests__/ inside component folder +src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx + +# For pure helpers: co-located .test.ts +src/app/(platform)/library/helpers.test.ts +``` + +### Custom MSW overrides + +When the auto-generated faker data is not enough, override with specific data: + +```tsx +import { http, HttpResponse } from "msw"; + +server.use( + http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => { + return HttpResponse.json({ + agents: [{ id: "1", name: "Test Agent", description: "A test agent" }], + pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 }, + }); + }), +); +``` + +Use the proxy URL pattern: `http://localhost:3000/api/proxy/api/v{version}/{path}` — this matches the MSW base URL configured in `orval.config.ts`. + +## Step 7: Run and verify + +After writing all tests: + +```bash +cd autogpt_platform/frontend +pnpm test:unit --reporter=verbose +``` + +If tests fail: + +1. Read the error output carefully +2. Fix the test (not the source code, unless there is a genuine bug) +3. Re-run until all pass + +Then run the full checks: + +```bash +pnpm format +pnpm lint +pnpm types +``` diff --git a/.github/workflows/platform-fullstack-ci.yml b/.github/workflows/platform-fullstack-ci.yml index fc772171b1..605c13c38b 100644 --- a/.github/workflows/platform-fullstack-ci.yml +++ b/.github/workflows/platform-fullstack-ci.yml @@ -160,6 +160,7 @@ jobs: run: | cp ../backend/.env.default ../backend/.env echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env + echo "SCHEDULER_STARTUP_EMBEDDING_BACKFILL=false" >> ../backend/.env env: # Used by E2E test data script to generate embeddings for approved store agents OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -179,21 +180,30 @@ jobs: pip install pyyaml # Resolve extends and generate a flat compose file that bake can understand + export NEXT_PUBLIC_SOURCEMAPS NEXT_PUBLIC_PW_TEST docker compose -f docker-compose.yml config > docker-compose.resolved.yml + # Ensure NEXT_PUBLIC_SOURCEMAPS is in resolved compose + # (docker compose config on some versions drops this arg) + if ! grep -q "NEXT_PUBLIC_SOURCEMAPS" docker-compose.resolved.yml; then + echo "Injecting NEXT_PUBLIC_SOURCEMAPS into resolved compose (docker compose config dropped it)" + sed -i '/NEXT_PUBLIC_PW_TEST/a\ NEXT_PUBLIC_SOURCEMAPS: "true"' docker-compose.resolved.yml + fi + # Add cache configuration to the resolved compose file python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \ --source docker-compose.resolved.yml \ --cache-from "type=gha" \ --cache-to "type=gha,mode=max" \ --backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \ - --frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \ + --frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}-sourcemaps" \ --git-ref "${{ github.ref }}" # Build with bake using the resolved compose file (now includes cache config) docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load env: NEXT_PUBLIC_PW_TEST: true + NEXT_PUBLIC_SOURCEMAPS: true - name: Set up tests - Cache E2E test data id: e2e-data-cache @@ -279,16 +289,38 @@ jobs: cache: "pnpm" cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml + - name: Set up tests - Cache Playwright browsers + uses: actions/cache@v5 + with: + path: ~/.cache/ms-playwright + key: playwright-${{ runner.os }}-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} + restore-keys: | + playwright-${{ runner.os }}- + + - name: Copy source maps from Docker for E2E coverage + run: | + FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend) + docker cp "$FRONTEND_CONTAINER":/app/.next/static .next-static-coverage + - name: Set up tests - Install dependencies run: pnpm install --frozen-lockfile - name: Set up tests - Install browser 'chromium' run: pnpm playwright install --with-deps chromium - - name: Run Playwright tests - run: pnpm test:no-build + - name: Run Playwright E2E suite + run: pnpm test:e2e:no-build continue-on-error: false + - name: Upload E2E coverage to Codecov + if: ${{ !cancelled() }} + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + flags: platform-frontend-e2e + files: ./autogpt_platform/frontend/coverage/e2e/cobertura-coverage.xml + disable_search: true + - name: Upload Playwright report if: always() uses: actions/upload-artifact@v4 diff --git a/.gitignore b/.gitignore index 9a9db80e40..53df57dc70 100644 --- a/.gitignore +++ b/.gitignore @@ -187,9 +187,12 @@ autogpt_platform/backend/settings.py .claude/settings.local.json CLAUDE.local.md /autogpt_platform/backend/logs +/autogpt_platform/backend/poetry.toml # Test database test.db .next # Implementation plans (generated by AI agents) plans/ +.claude/worktrees/ +test-results/ diff --git a/.gitleaks.toml b/.gitleaks.toml new file mode 100644 index 0000000000..75867a7f50 --- /dev/null +++ b/.gitleaks.toml @@ -0,0 +1,36 @@ +title = "AutoGPT Gitleaks Config" + +[extend] +useDefault = true + +[allowlist] +description = "Global allowlist" +paths = [ + # Template/example env files (no real secrets) + '''\.env\.(default|example|template)$''', + # Lock files + '''pnpm-lock\.yaml$''', + '''poetry\.lock$''', + # Secrets baseline + '''\.secrets\.baseline$''', + # Build artifacts and caches (should not be committed) + '''__pycache__/''', + '''classic/frontend/build/''', + # Docker dev setup (local dev JWTs/keys only) + '''autogpt_platform/db/docker/''', + # Load test configs (dev JWTs) + '''load-tests/configs/''', + # Test files with fake/fixture keys (_test.py, test_*.py, conftest.py) + '''(_test|test_.*|conftest)\.py$''', + # Documentation (only contains placeholder keys in curl/API examples) + '''docs/.*\.md$''', + # Firebase config (public API keys by design) + '''google-services\.json$''', + '''classic/frontend/(lib|web)/''', +] +# CI test-only encryption key (marked DO NOT USE IN PRODUCTION) +regexes = [ + '''dvziYgz0KSK8FENhju0ZYi8''', + # LLM model name enum values falsely flagged as API keys + '''Llama-\d.*Instruct''', +] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9dc1951992..b5527825ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,9 +23,15 @@ repos: - id: detect-secrets name: Detect secrets description: Detects high entropy strings that are likely to be passwords. + args: ["--baseline", ".secrets.baseline"] files: ^autogpt_platform/ - exclude: pnpm-lock\.yaml$ - stages: [pre-push] + exclude: (pnpm-lock\.yaml|\.env\.(default|example|template))$ + + - repo: https://github.com/gitleaks/gitleaks + rev: v8.24.3 + hooks: + - id: gitleaks + name: Detect secrets (gitleaks) - repo: local # For proper type checking, all dependencies need to be up-to-date. diff --git a/.secrets.baseline b/.secrets.baseline new file mode 100644 index 0000000000..c2b1f3430a --- /dev/null +++ b/.secrets.baseline @@ -0,0 +1,471 @@ +{ + "version": "1.5.0", + "plugins_used": [ + { + "name": "ArtifactoryDetector" + }, + { + "name": "AWSKeyDetector" + }, + { + "name": "AzureStorageKeyDetector" + }, + { + "name": "Base64HighEntropyString", + "limit": 4.5 + }, + { + "name": "BasicAuthDetector" + }, + { + "name": "CloudantDetector" + }, + { + "name": "DiscordBotTokenDetector" + }, + { + "name": "GitHubTokenDetector" + }, + { + "name": "GitLabTokenDetector" + }, + { + "name": "HexHighEntropyString", + "limit": 3.0 + }, + { + "name": "IbmCloudIamDetector" + }, + { + "name": "IbmCosHmacDetector" + }, + { + "name": "IPPublicDetector" + }, + { + "name": "JwtTokenDetector" + }, + { + "name": "KeywordDetector", + "keyword_exclude": "" + }, + { + "name": "MailchimpDetector" + }, + { + "name": "NpmDetector" + }, + { + "name": "OpenAIDetector" + }, + { + "name": "PrivateKeyDetector" + }, + { + "name": "PypiTokenDetector" + }, + { + "name": "SendGridDetector" + }, + { + "name": "SlackDetector" + }, + { + "name": "SoftlayerDetector" + }, + { + "name": "SquareOAuthDetector" + }, + { + "name": "StripeDetector" + }, + { + "name": "TelegramBotTokenDetector" + }, + { + "name": "TwilioKeyDetector" + } + ], + "filters_used": [ + { + "path": "detect_secrets.filters.allowlist.is_line_allowlisted" + }, + { + "path": "detect_secrets.filters.common.is_baseline_file", + "filename": ".secrets.baseline" + }, + { + "path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies", + "min_level": 2 + }, + { + "path": "detect_secrets.filters.heuristic.is_indirect_reference" + }, + { + "path": "detect_secrets.filters.heuristic.is_likely_id_string" + }, + { + "path": "detect_secrets.filters.heuristic.is_lock_file" + }, + { + "path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string" + }, + { + "path": "detect_secrets.filters.heuristic.is_potential_uuid" + }, + { + "path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign" + }, + { + "path": "detect_secrets.filters.heuristic.is_sequential_string" + }, + { + "path": "detect_secrets.filters.heuristic.is_swagger_file" + }, + { + "path": "detect_secrets.filters.heuristic.is_templated_secret" + }, + { + "path": "detect_secrets.filters.regex.should_exclude_file", + "pattern": [ + "\\.env$", + "pnpm-lock\\.yaml$", + "\\.env\\.(default|example|template)$", + "__pycache__", + "_test\\.py$", + "test_.*\\.py$", + "conftest\\.py$", + "poetry\\.lock$", + "node_modules" + ] + } + ], + "results": { + "autogpt_platform/backend/backend/api/external/v1/integrations.py": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/backend/backend/api/external/v1/integrations.py", + "hashed_secret": "665b1e3851eefefa3fb878654292f16597d25155", + "is_verified": false, + "line_number": 289 + } + ], + "autogpt_platform/backend/backend/blocks/airtable/_config.py": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/backend/backend/blocks/airtable/_config.py", + "hashed_secret": "57e168b03afb7c1ee3cdc4ee3db2fe1cc6e0df26", + "is_verified": false, + "line_number": 29 + } + ], + "autogpt_platform/backend/backend/blocks/dataforseo/_config.py": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/backend/backend/blocks/dataforseo/_config.py", + "hashed_secret": "32ce93887331fa5d192f2876ea15ec000c7d58b8", + "is_verified": false, + "line_number": 12 + } + ], + "autogpt_platform/backend/backend/blocks/github/checks.py": [ + { + "type": "Hex High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/github/checks.py", + "hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238", + "is_verified": false, + "line_number": 108 + } + ], + "autogpt_platform/backend/backend/blocks/github/ci.py": [ + { + "type": "Hex High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/github/ci.py", + "hashed_secret": "90bd1b48e958257948487b90bee080ba5ed00caa", + "is_verified": false, + "line_number": 123 + } + ], + "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json": [ + { + "type": "Hex High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json", + "hashed_secret": "f96896dafced7387dcd22343b8ea29d3d2c65663", + "is_verified": false, + "line_number": 42 + }, + { + "type": "Hex High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json", + "hashed_secret": "b80a94d5e70bedf4f5f89d2f5a5255cc9492d12e", + "is_verified": false, + "line_number": 193 + }, + { + "type": "Hex High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json", + "hashed_secret": "75b17e517fe1b3136394f6bec80c4f892da75e42", + "is_verified": false, + "line_number": 344 + }, + { + "type": "Hex High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json", + "hashed_secret": "b0bfb5e4e2394e7f8906e5ed1dffd88b2bc89dd5", + "is_verified": false, + "line_number": 534 + } + ], + "autogpt_platform/backend/backend/blocks/github/statuses.py": [ + { + "type": "Hex High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/github/statuses.py", + "hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238", + "is_verified": false, + "line_number": 85 + } + ], + "autogpt_platform/backend/backend/blocks/google/docs.py": [ + { + "type": "Hex High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/google/docs.py", + "hashed_secret": "c95da0c6696342c867ef0c8258d2f74d20fd94d4", + "is_verified": false, + "line_number": 203 + } + ], + "autogpt_platform/backend/backend/blocks/google/sheets.py": [ + { + "type": "Base64 High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/google/sheets.py", + "hashed_secret": "bd5a04fa3667e693edc13239b6d310c5c7a8564b", + "is_verified": false, + "line_number": 57 + } + ], + "autogpt_platform/backend/backend/blocks/linear/_config.py": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/backend/backend/blocks/linear/_config.py", + "hashed_secret": "b37f020f42d6d613b6ce30103e4d408c4499b3bb", + "is_verified": false, + "line_number": 53 + } + ], + "autogpt_platform/backend/backend/blocks/medium.py": [ + { + "type": "Hex High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/medium.py", + "hashed_secret": "ff998abc1ce6d8f01a675fa197368e44c8916e9c", + "is_verified": false, + "line_number": 131 + } + ], + "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py": [ + { + "type": "Hex High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py", + "hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0", + "is_verified": false, + "line_number": 55 + } + ], + "autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [ + { + "type": "Hex High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/slant3d/webhook.py", + "hashed_secret": "36263c76947443b2f6e6b78153967ac4a7da99f9", + "is_verified": false, + "line_number": 100 + } + ], + "autogpt_platform/backend/backend/blocks/talking_head.py": [ + { + "type": "Base64 High Entropy String", + "filename": "autogpt_platform/backend/backend/blocks/talking_head.py", + "hashed_secret": "44ce2d66222529eea4a32932823466fc0601c799", + "is_verified": false, + "line_number": 113 + } + ], + "autogpt_platform/backend/backend/blocks/wordpress/_config.py": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/backend/backend/blocks/wordpress/_config.py", + "hashed_secret": "e62679512436161b78e8a8d68c8829c2a1031ccb", + "is_verified": false, + "line_number": 17 + } + ], + "autogpt_platform/backend/backend/util/cache.py": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/backend/backend/util/cache.py", + "hashed_secret": "37f0c918c3fa47ca4a70e42037f9f123fdfbc75b", + "is_verified": false, + "line_number": 449 + } + ], + "autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts", + "hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8", + "is_verified": false, + "line_number": 6 + } + ], + "autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json", + "hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d", + "is_verified": false, + "line_number": 5 + } + ], + "autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json", + "hashed_secret": "5a6d1c612954979ea99ee33dbb2d231b00f6ac0a", + "is_verified": false, + "line_number": 5 + } + ], + "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts", + "hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679", + "is_verified": false, + "line_number": 6 + }, + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts", + "hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380", + "is_verified": false, + "line_number": 8 + } + ], + "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts", + "hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679", + "is_verified": false, + "line_number": 5 + }, + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts", + "hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380", + "is_verified": false, + "line_number": 7 + } + ], + "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx", + "hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679", + "is_verified": false, + "line_number": 192 + }, + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx", + "hashed_secret": "86275db852204937bbdbdebe5fabe8536e030ab6", + "is_verified": false, + "line_number": 193 + } + ], + "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts", + "hashed_secret": "47acd2028cf81b5da88ddeedb2aea4eca4b71fbd", + "is_verified": false, + "line_number": 102 + }, + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts", + "hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d", + "is_verified": false, + "line_number": 103 + } + ], + "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts": [ + { + "type": "Base64 High Entropy String", + "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts", + "hashed_secret": "9c486c92f1a7420e1045c7ad963fbb7ba3621025", + "is_verified": false, + "line_number": 73 + }, + { + "type": "Base64 High Entropy String", + "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts", + "hashed_secret": "9277508c7a6effc8fb59163efbfada189e35425c", + "is_verified": false, + "line_number": 75 + }, + { + "type": "Base64 High Entropy String", + "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts", + "hashed_secret": "8dc7e2cb1d0935897d541bf5facab389b8a50340", + "is_verified": false, + "line_number": 77 + }, + { + "type": "Base64 High Entropy String", + "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts", + "hashed_secret": "79a26ad48775944299be6aaf9fb1d5302c1ed75b", + "is_verified": false, + "line_number": 79 + }, + { + "type": "Base64 High Entropy String", + "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts", + "hashed_secret": "a3b62b44500a1612e48d4cab8294df81561b3b1a", + "is_verified": false, + "line_number": 81 + }, + { + "type": "Base64 High Entropy String", + "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts", + "hashed_secret": "a58979bd0b21ef4f50417d001008e60dd7a85c64", + "is_verified": false, + "line_number": 83 + }, + { + "type": "Base64 High Entropy String", + "filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts", + "hashed_secret": "6cb6e075f8e8c7c850f9d128d6608e5dbe209a79", + "is_verified": false, + "line_number": 85 + } + ], + "autogpt_platform/frontend/src/lib/constants.ts": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/lib/constants.ts", + "hashed_secret": "27b924db06a28cc755fb07c54f0fddc30659fe4d", + "is_verified": false, + "line_number": 13 + } + ], + "autogpt_platform/frontend/src/tests/credentials/index.ts": [ + { + "type": "Secret Keyword", + "filename": "autogpt_platform/frontend/src/tests/credentials/index.ts", + "hashed_secret": "c18006fc138809314751cd1991f1e0b820fabd37", + "is_verified": false, + "line_number": 4 + } + ] + }, + "generated_at": "2026-04-09T14:20:23Z" +} diff --git a/AGENTS.md b/AGENTS.md index f88741ae3a..d0b325167c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -30,7 +30,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference: - Regenerate with `pnpm generate:api` - Pattern: `use{Method}{Version}{OperationName}` 4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only -5. **Testing**: Add Storybook stories for new components, Playwright for E2E +5. **Testing**: Integration tests (Vitest + RTL + MSW) are the default (~90%, page-level). Playwright for E2E critical flows. Storybook for design system components. See `autogpt_platform/frontend/TESTING.md` 6. **Code conventions**: Function declarations (not arrow functions) for components/handlers - Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component @@ -47,7 +47,9 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference: ## Testing - Backend: `poetry run test` (runs pytest with a docker based postgres + prisma). -- Frontend: `pnpm test` or `pnpm test-ui` for Playwright tests. See `docs/content/platform/contributing/tests.md` for tips. +- Frontend integration tests: `pnpm test:unit` (Vitest + RTL + MSW, primary testing approach). +- Frontend E2E tests: `pnpm test` or `pnpm test-ui` for Playwright tests. +- See `autogpt_platform/frontend/TESTING.md` for the full testing strategy. Always run the relevant linters and tests before committing. Use conventional commit messages for all commits (e.g. `feat(backend): add API`). diff --git a/autogpt_platform/.gitignore b/autogpt_platform/.gitignore index 3e31a9970e..bc70dc96bc 100644 --- a/autogpt_platform/.gitignore +++ b/autogpt_platform/.gitignore @@ -1,3 +1,6 @@ *.ignore.* *.ign.* .application.logs + +# Claude Code local settings only — the rest of .claude/ is shared (skills etc.) +.claude/settings.local.json diff --git a/autogpt_platform/analytics/queries/platform_cost_log.sql b/autogpt_platform/analytics/queries/platform_cost_log.sql new file mode 100644 index 0000000000..b3e33d7515 --- /dev/null +++ b/autogpt_platform/analytics/queries/platform_cost_log.sql @@ -0,0 +1,100 @@ +-- ============================================================= +-- View: analytics.platform_cost_log +-- Looker source alias: ds115 | Charts: 0 +-- ============================================================= +-- DESCRIPTION +-- One row per platform cost log entry (last 90 days). +-- Tracks real API spend at the call level: provider, model, +-- token counts (including Anthropic cache tokens), cost in +-- microdollars, and the block/execution that incurred the cost. +-- Joins the User table to provide email for per-user breakdowns. +-- +-- SOURCE TABLES +-- platform.PlatformCostLog — Per-call cost records +-- platform.User — User email +-- +-- OUTPUT COLUMNS +-- id TEXT Log entry UUID +-- createdAt TIMESTAMPTZ When the cost was recorded +-- userId TEXT User who incurred the cost (nullable) +-- email TEXT User email (nullable) +-- graphExecId TEXT Graph execution UUID (nullable) +-- nodeExecId TEXT Node execution UUID (nullable) +-- blockName TEXT Block that made the API call (nullable) +-- provider TEXT API provider, lowercase (e.g. 'openai', 'anthropic') +-- model TEXT Model name (nullable) +-- trackingType TEXT Cost unit: 'tokens' | 'cost_usd' | 'characters' | etc. +-- costMicrodollars BIGINT Cost in microdollars (divide by 1,000,000 for USD) +-- costUsd FLOAT Cost in USD (costMicrodollars / 1,000,000) +-- inputTokens INT Prompt/input tokens (nullable) +-- outputTokens INT Completion/output tokens (nullable) +-- cacheReadTokens INT Anthropic cache-read tokens billed at 10% (nullable) +-- cacheCreationTokens INT Anthropic cache-write tokens billed at 125% (nullable) +-- totalTokens INT inputTokens + outputTokens (nullable if either is null) +-- duration FLOAT API call duration in seconds (nullable) +-- +-- WINDOW +-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days) +-- +-- EXAMPLE QUERIES +-- -- Total spend by provider (last 90 days) +-- SELECT provider, SUM("costUsd") AS total_usd, COUNT(*) AS calls +-- FROM analytics.platform_cost_log +-- GROUP BY 1 ORDER BY total_usd DESC; +-- +-- -- Spend by model +-- SELECT provider, model, SUM("costUsd") AS total_usd, +-- SUM("inputTokens") AS input_tokens, +-- SUM("outputTokens") AS output_tokens +-- FROM analytics.platform_cost_log +-- WHERE model IS NOT NULL +-- GROUP BY 1, 2 ORDER BY total_usd DESC; +-- +-- -- Top 20 users by spend +-- SELECT "userId", email, SUM("costUsd") AS total_usd, COUNT(*) AS calls +-- FROM analytics.platform_cost_log +-- WHERE "userId" IS NOT NULL +-- GROUP BY 1, 2 ORDER BY total_usd DESC LIMIT 20; +-- +-- -- Daily spend trend +-- SELECT DATE_TRUNC('day', "createdAt") AS day, +-- SUM("costUsd") AS daily_usd, +-- COUNT(*) AS calls +-- FROM analytics.platform_cost_log +-- GROUP BY 1 ORDER BY 1; +-- +-- -- Cache hit rate for Anthropic (cache reads vs total reads) +-- SELECT DATE_TRUNC('day', "createdAt") AS day, +-- SUM("cacheReadTokens")::float / +-- NULLIF(SUM("inputTokens" + COALESCE("cacheReadTokens", 0)), 0) AS cache_hit_rate +-- FROM analytics.platform_cost_log +-- WHERE provider = 'anthropic' +-- GROUP BY 1 ORDER BY 1; +-- ============================================================= + +SELECT + p."id" AS id, + p."createdAt" AS createdAt, + p."userId" AS userId, + u."email" AS email, + p."graphExecId" AS graphExecId, + p."nodeExecId" AS nodeExecId, + p."blockName" AS blockName, + p."provider" AS provider, + p."model" AS model, + p."trackingType" AS trackingType, + p."costMicrodollars" AS costMicrodollars, + p."costMicrodollars"::float / 1000000.0 AS costUsd, + p."inputTokens" AS inputTokens, + p."outputTokens" AS outputTokens, + p."cacheReadTokens" AS cacheReadTokens, + p."cacheCreationTokens" AS cacheCreationTokens, + CASE + WHEN p."inputTokens" IS NOT NULL AND p."outputTokens" IS NOT NULL + THEN p."inputTokens" + p."outputTokens" + ELSE NULL + END AS totalTokens, + p."duration" AS duration +FROM platform."PlatformCostLog" p +LEFT JOIN platform."User" u ON u."id" = p."userId" +WHERE p."createdAt" > CURRENT_DATE - INTERVAL '90 days' diff --git a/autogpt_platform/backend/.env.default b/autogpt_platform/backend/.env.default index 8ba3f758d9..67444c2e36 100644 --- a/autogpt_platform/backend/.env.default +++ b/autogpt_platform/backend/.env.default @@ -58,6 +58,17 @@ V0_API_KEY= OPEN_ROUTER_API_KEY= NVIDIA_API_KEY= +# Graphiti Temporal Knowledge Graph Memory +# Rollout controlled by LaunchDarkly flag "graphiti-memory" +# LLM key falls back to CHAT_API_KEY (AutoPilot), then OPEN_ROUTER_API_KEY. +# Embedder key falls back to CHAT_OPENAI_API_KEY (AutoPilot), then OPENAI_API_KEY. +GRAPHITI_FALKORDB_HOST=localhost +GRAPHITI_FALKORDB_PORT=6380 +GRAPHITI_FALKORDB_PASSWORD= +GRAPHITI_LLM_MODEL=gpt-4.1-mini +GRAPHITI_EMBEDDER_MODEL=text-embedding-3-small +GRAPHITI_SEMAPHORE_LIMIT=5 + # Langfuse Prompt Management # Used for managing the CoPilot system prompt externally # Get credentials from https://cloud.langfuse.com or your self-hosted instance @@ -168,6 +179,9 @@ MEM0_API_KEY= OPENWEATHERMAP_API_KEY= GOOGLE_MAPS_API_KEY= +# Platform Bot Linking +PLATFORM_LINK_BASE_URL=http://localhost:3000/link + # Communication Services DISCORD_BOT_TOKEN= MEDIUM_API_KEY= diff --git a/autogpt_platform/backend/agents/calculator-agent.json b/autogpt_platform/backend/agents/calculator-agent.json new file mode 100644 index 0000000000..9851b1496b --- /dev/null +++ b/autogpt_platform/backend/agents/calculator-agent.json @@ -0,0 +1,166 @@ +{ + "id": "858e2226-e047-4d19-a832-3be4a134d155", + "version": 2, + "is_active": true, + "name": "Calculator agent", + "description": "", + "instructions": null, + "recommended_schedule_cron": null, + "forked_from_id": null, + "forked_from_version": null, + "user_id": "", + "created_at": "2026-04-13T03:45:11.241Z", + "nodes": [ + { + "id": "6762da5d-6915-4836-a431-6dcd7d36a54a", + "block_id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b", + "input_default": { + "name": "Input", + "secret": false, + "advanced": false + }, + "metadata": { + "position": { + "x": -188.2244873046875, + "y": 95 + } + }, + "input_links": [], + "output_links": [ + { + "id": "432c7caa-49b9-4b70-bd21-2fa33a569601", + "source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a", + "sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690", + "source_name": "result", + "sink_name": "a", + "is_static": true + } + ], + "graph_id": "858e2226-e047-4d19-a832-3be4a134d155", + "graph_version": 2, + "webhook_id": null + }, + { + "id": "65429c9e-a0c6-4032-a421-6899c394fa74", + "block_id": "363ae599-353e-4804-937e-b2ee3cef3da4", + "input_default": { + "name": "Output", + "secret": false, + "advanced": false, + "escape_html": false + }, + "metadata": { + "position": { + "x": 825.198974609375, + "y": 123.75 + } + }, + "input_links": [ + { + "id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3", + "source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690", + "sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74", + "source_name": "result", + "sink_name": "value", + "is_static": false + } + ], + "output_links": [], + "graph_id": "858e2226-e047-4d19-a832-3be4a134d155", + "graph_version": 2, + "webhook_id": null + }, + { + "id": "bf4a15ff-b0c4-4032-a21b-5880224af690", + "block_id": "b1ab9b19-67a6-406d-abf5-2dba76d00c79", + "input_default": { + "b": 34, + "operation": "Add", + "round_result": false + }, + "metadata": { + "position": { + "x": 323.0255126953125, + "y": 121.25 + } + }, + "input_links": [ + { + "id": "432c7caa-49b9-4b70-bd21-2fa33a569601", + "source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a", + "sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690", + "source_name": "result", + "sink_name": "a", + "is_static": true + } + ], + "output_links": [ + { + "id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3", + "source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690", + "sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74", + "source_name": "result", + "sink_name": "value", + "is_static": false + } + ], + "graph_id": "858e2226-e047-4d19-a832-3be4a134d155", + "graph_version": 2, + "webhook_id": null + } + ], + "links": [ + { + "id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3", + "source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690", + "sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74", + "source_name": "result", + "sink_name": "value", + "is_static": false + }, + { + "id": "432c7caa-49b9-4b70-bd21-2fa33a569601", + "source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a", + "sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690", + "source_name": "result", + "sink_name": "a", + "is_static": true + } + ], + "sub_graphs": [], + "input_schema": { + "type": "object", + "properties": { + "Input": { + "advanced": false, + "secret": false, + "title": "Input" + } + }, + "required": [ + "Input" + ] + }, + "output_schema": { + "type": "object", + "properties": { + "Output": { + "advanced": false, + "secret": false, + "title": "Output" + } + }, + "required": [ + "Output" + ] + }, + "has_external_trigger": false, + "has_human_in_the_loop": false, + "has_sensitive_action": false, + "trigger_setup_info": null, + "credentials_input_schema": { + "type": "object", + "properties": {}, + "required": [] + } +} \ No newline at end of file diff --git a/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes.py b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes.py new file mode 100644 index 0000000000..4cb8ff0729 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes.py @@ -0,0 +1,932 @@ +import asyncio +import logging +from typing import List + +from autogpt_libs.auth import requires_admin_user +from autogpt_libs.auth.models import User as AuthUser +from fastapi import APIRouter, HTTPException, Security +from prisma.enums import AgentExecutionStatus +from pydantic import BaseModel + +from backend.api.features.admin.model import ( + AgentDiagnosticsResponse, + ExecutionDiagnosticsResponse, +) +from backend.data.diagnostics import ( + FailedExecutionDetail, + OrphanedScheduleDetail, + RunningExecutionDetail, + ScheduleDetail, + ScheduleHealthMetrics, + cleanup_all_stuck_queued_executions, + cleanup_orphaned_executions_bulk, + cleanup_orphaned_schedules_bulk, + get_agent_diagnostics, + get_all_orphaned_execution_ids, + get_all_schedules_details, + get_all_stuck_queued_execution_ids, + get_execution_diagnostics, + get_failed_executions_count, + get_failed_executions_details, + get_invalid_executions_details, + get_long_running_executions_details, + get_orphaned_executions_details, + get_orphaned_schedules_details, + get_running_executions_details, + get_schedule_health_metrics, + get_stuck_queued_executions_details, + stop_all_long_running_executions, +) +from backend.data.execution import get_graph_executions +from backend.executor.utils import add_graph_execution, stop_graph_execution + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/admin", + tags=["diagnostics", "admin"], + dependencies=[Security(requires_admin_user)], +) + + +class RunningExecutionsListResponse(BaseModel): + """Response model for list of running executions""" + + executions: List[RunningExecutionDetail] + total: int + + +class FailedExecutionsListResponse(BaseModel): + """Response model for list of failed executions""" + + executions: List[FailedExecutionDetail] + total: int + + +class StopExecutionRequest(BaseModel): + """Request model for stopping a single execution""" + + execution_id: str + + +class StopExecutionsRequest(BaseModel): + """Request model for stopping multiple executions""" + + execution_ids: List[str] + + +class StopExecutionResponse(BaseModel): + """Response model for stop execution operations""" + + success: bool + stopped_count: int = 0 + message: str + + +class RequeueExecutionResponse(BaseModel): + """Response model for requeue execution operations""" + + success: bool + requeued_count: int = 0 + message: str + + +@router.get( + "/diagnostics/executions", + response_model=ExecutionDiagnosticsResponse, + summary="Get Execution Diagnostics", +) +async def get_execution_diagnostics_endpoint(): + """ + Get comprehensive diagnostic information about execution status. + + Returns all execution metrics including: + - Current state (running, queued) + - Orphaned executions (>24h old, likely not in executor) + - Failure metrics (1h, 24h, rate) + - Long-running detection (stuck >1h, >24h) + - Stuck queued detection + - Throughput metrics (completions/hour) + - RabbitMQ queue depths + """ + logger.info("Getting execution diagnostics") + + diagnostics = await get_execution_diagnostics() + + response = ExecutionDiagnosticsResponse( + running_executions=diagnostics.running_count, + queued_executions_db=diagnostics.queued_db_count, + queued_executions_rabbitmq=diagnostics.rabbitmq_queue_depth, + cancel_queue_depth=diagnostics.cancel_queue_depth, + orphaned_running=diagnostics.orphaned_running, + orphaned_queued=diagnostics.orphaned_queued, + failed_count_1h=diagnostics.failed_count_1h, + failed_count_24h=diagnostics.failed_count_24h, + failure_rate_24h=diagnostics.failure_rate_24h, + stuck_running_24h=diagnostics.stuck_running_24h, + stuck_running_1h=diagnostics.stuck_running_1h, + oldest_running_hours=diagnostics.oldest_running_hours, + stuck_queued_1h=diagnostics.stuck_queued_1h, + queued_never_started=diagnostics.queued_never_started, + invalid_queued_with_start=diagnostics.invalid_queued_with_start, + invalid_running_without_start=diagnostics.invalid_running_without_start, + completed_1h=diagnostics.completed_1h, + completed_24h=diagnostics.completed_24h, + throughput_per_hour=diagnostics.throughput_per_hour, + timestamp=diagnostics.timestamp, + ) + + logger.info( + f"Execution diagnostics: running={diagnostics.running_count}, " + f"queued_db={diagnostics.queued_db_count}, " + f"orphaned={diagnostics.orphaned_running + diagnostics.orphaned_queued}, " + f"failed_24h={diagnostics.failed_count_24h}" + ) + + return response + + +@router.get( + "/diagnostics/agents", + response_model=AgentDiagnosticsResponse, + summary="Get Agent Diagnostics", +) +async def get_agent_diagnostics_endpoint(): + """ + Get diagnostic information about agents. + + Returns: + - agents_with_active_executions: Number of unique agents with running/queued executions + - timestamp: Current timestamp + """ + logger.info("Getting agent diagnostics") + + diagnostics = await get_agent_diagnostics() + + response = AgentDiagnosticsResponse( + agents_with_active_executions=diagnostics.agents_with_active_executions, + timestamp=diagnostics.timestamp, + ) + + logger.info( + f"Agent diagnostics: with_active_executions={diagnostics.agents_with_active_executions}" + ) + + return response + + +@router.get( + "/diagnostics/executions/running", + response_model=RunningExecutionsListResponse, + summary="List Running Executions", +) +async def list_running_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of running and queued executions (recent, likely active). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of running executions with details + """ + logger.info(f"Listing running executions (limit={limit}, offset={offset})") + + executions = await get_running_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.running_count + diagnostics.queued_db_count + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/orphaned", + response_model=RunningExecutionsListResponse, + summary="List Orphaned Executions", +) +async def list_orphaned_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of orphaned executions (>24h old, likely not in executor). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of orphaned executions with details + """ + logger.info(f"Listing orphaned executions (limit={limit}, offset={offset})") + + executions = await get_orphaned_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.orphaned_running + diagnostics.orphaned_queued + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/failed", + response_model=FailedExecutionsListResponse, + summary="List Failed Executions", +) +async def list_failed_executions( + limit: int = 100, + offset: int = 0, + hours: int = 24, +): + """ + Get detailed list of failed executions. + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + hours: Number of hours to look back (default 24) + + Returns: + List of failed executions with error details + """ + logger.info( + f"Listing failed executions (limit={limit}, offset={offset}, hours={hours})" + ) + + executions = await get_failed_executions_details( + limit=limit, offset=offset, hours=hours + ) + + # Get total count for pagination + # Always count actual total for given hours parameter + total = await get_failed_executions_count(hours=hours) + + return FailedExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/long-running", + response_model=RunningExecutionsListResponse, + summary="List Long-Running Executions", +) +async def list_long_running_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of long-running executions (RUNNING status >24h). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of long-running executions with details + """ + logger.info(f"Listing long-running executions (limit={limit}, offset={offset})") + + executions = await get_long_running_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.stuck_running_24h + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/stuck-queued", + response_model=RunningExecutionsListResponse, + summary="List Stuck Queued Executions", +) +async def list_stuck_queued_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of stuck queued executions (QUEUED >1h, never started). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of stuck queued executions with details + """ + logger.info(f"Listing stuck queued executions (limit={limit}, offset={offset})") + + executions = await get_stuck_queued_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.stuck_queued_1h + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/invalid", + response_model=RunningExecutionsListResponse, + summary="List Invalid Executions", +) +async def list_invalid_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of executions in invalid states (READ-ONLY). + + Invalid states indicate data corruption and require manual investigation: + - QUEUED but has startedAt (impossible - can't start while queued) + - RUNNING but no startedAt (impossible - can't run without starting) + + ⚠️ NO BULK ACTIONS PROVIDED - These need case-by-case investigation. + + Each invalid execution likely has a different root cause (crashes, race conditions, + DB corruption). Investigate the execution history and logs to determine appropriate + action (manual cleanup, status fix, or leave as-is if system recovered). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of invalid state executions with details + """ + logger.info(f"Listing invalid state executions (limit={limit}, offset={offset})") + + executions = await get_invalid_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = ( + diagnostics.invalid_queued_with_start + + diagnostics.invalid_running_without_start + ) + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.post( + "/diagnostics/executions/requeue", + response_model=RequeueExecutionResponse, + summary="Requeue Stuck Execution", +) +async def requeue_single_execution( + request: StopExecutionRequest, # Reuse same request model (has execution_id) + user: AuthUser = Security(requires_admin_user), +): + """ + Requeue a stuck QUEUED execution (admin only). + + Uses add_graph_execution with existing graph_exec_id to requeue. + + ⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits. + + Args: + request: Contains execution_id to requeue + + Returns: + Success status and message + """ + logger.info(f"Admin {user.user_id} requeueing execution {request.execution_id}") + + # Get the execution (validation - must be QUEUED) + executions = await get_graph_executions( + graph_exec_id=request.execution_id, + statuses=[AgentExecutionStatus.QUEUED], + ) + + if not executions: + raise HTTPException( + status_code=404, + detail="Execution not found or not in QUEUED status", + ) + + execution = executions[0] + + # Use add_graph_execution in requeue mode + await add_graph_execution( + graph_id=execution.graph_id, + user_id=execution.user_id, + graph_version=execution.graph_version, + graph_exec_id=request.execution_id, # Requeue existing execution + ) + + return RequeueExecutionResponse( + success=True, + requeued_count=1, + message="Execution requeued successfully", + ) + + +@router.post( + "/diagnostics/executions/requeue-bulk", + response_model=RequeueExecutionResponse, + summary="Requeue Multiple Stuck Executions", +) +async def requeue_multiple_executions( + request: StopExecutionsRequest, # Reuse same request model (has execution_ids) + user: AuthUser = Security(requires_admin_user), +): + """ + Requeue multiple stuck QUEUED executions (admin only). + + Uses add_graph_execution with existing graph_exec_id to requeue. + + ⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits. + + Args: + request: Contains list of execution_ids to requeue + + Returns: + Number of executions requeued and success message + """ + logger.info( + f"Admin {user.user_id} requeueing {len(request.execution_ids)} executions" + ) + + # Get executions by ID list (must be QUEUED) + executions = await get_graph_executions( + execution_ids=request.execution_ids, + statuses=[AgentExecutionStatus.QUEUED], + ) + + if not executions: + return RequeueExecutionResponse( + success=False, + requeued_count=0, + message="No QUEUED executions found to requeue", + ) + + # Requeue all executions in parallel using add_graph_execution + async def requeue_one(exec) -> bool: + try: + await add_graph_execution( + graph_id=exec.graph_id, + user_id=exec.user_id, + graph_version=exec.graph_version, + graph_exec_id=exec.id, # Requeue existing + ) + return True + except Exception as e: + logger.error(f"Failed to requeue {exec.id}: {e}") + return False + + results = await asyncio.gather( + *[requeue_one(exec) for exec in executions], return_exceptions=False + ) + + requeued_count = sum(1 for success in results if success) + + return RequeueExecutionResponse( + success=requeued_count > 0, + requeued_count=requeued_count, + message=f"Requeued {requeued_count} of {len(request.execution_ids)} executions", + ) + + +@router.post( + "/diagnostics/executions/stop", + response_model=StopExecutionResponse, + summary="Stop Single Execution", +) +async def stop_single_execution( + request: StopExecutionRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Stop a single execution (admin only). + + Uses robust stop_graph_execution which cascades to children and waits for termination. + + Args: + request: Contains execution_id to stop + + Returns: + Success status and message + """ + logger.info(f"Admin {user.user_id} stopping execution {request.execution_id}") + + # Get the execution to find its owner user_id (required by stop_graph_execution) + executions = await get_graph_executions( + graph_exec_id=request.execution_id, + ) + + if not executions: + raise HTTPException(status_code=404, detail="Execution not found") + + execution = executions[0] + + # Use robust stop_graph_execution (cascades to children, waits for termination) + await stop_graph_execution( + user_id=execution.user_id, + graph_exec_id=request.execution_id, + wait_timeout=15.0, + cascade=True, + ) + + return StopExecutionResponse( + success=True, + stopped_count=1, + message="Execution stopped successfully", + ) + + +@router.post( + "/diagnostics/executions/stop-bulk", + response_model=StopExecutionResponse, + summary="Stop Multiple Executions", +) +async def stop_multiple_executions( + request: StopExecutionsRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Stop multiple active executions (admin only). + + Uses robust stop_graph_execution which cascades to children and waits for termination. + + Args: + request: Contains list of execution_ids to stop + + Returns: + Number of executions stopped and success message + """ + + logger.info( + f"Admin {user.user_id} stopping {len(request.execution_ids)} executions" + ) + + # Get executions by ID list + executions = await get_graph_executions( + execution_ids=request.execution_ids, + ) + + if not executions: + return StopExecutionResponse( + success=False, + stopped_count=0, + message="No executions found", + ) + + # Stop all executions in parallel using robust stop_graph_execution + async def stop_one(exec) -> bool: + try: + await stop_graph_execution( + user_id=exec.user_id, + graph_exec_id=exec.id, + wait_timeout=15.0, + cascade=True, + ) + return True + except Exception as e: + logger.error(f"Failed to stop execution {exec.id}: {e}") + return False + + results = await asyncio.gather( + *[stop_one(exec) for exec in executions], return_exceptions=False + ) + + stopped_count = sum(1 for success in results if success) + + return StopExecutionResponse( + success=stopped_count > 0, + stopped_count=stopped_count, + message=f"Stopped {stopped_count} of {len(request.execution_ids)} executions", + ) + + +@router.post( + "/diagnostics/executions/cleanup-orphaned", + response_model=StopExecutionResponse, + summary="Cleanup Orphaned Executions", +) +async def cleanup_orphaned_executions( + request: StopExecutionsRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup orphaned executions by directly updating DB status (admin only). + For executions in DB but not actually running in executor (old/stale records). + + Args: + request: Contains list of execution_ids to cleanup + + Returns: + Number of executions cleaned up and success message + """ + logger.info( + f"Admin {user.user_id} cleaning up {len(request.execution_ids)} orphaned executions" + ) + + cleaned_count = await cleanup_orphaned_executions_bulk( + request.execution_ids, user.user_id + ) + + return StopExecutionResponse( + success=cleaned_count > 0, + stopped_count=cleaned_count, + message=f"Cleaned up {cleaned_count} of {len(request.execution_ids)} orphaned executions", + ) + + +# ============================================================================ +# SCHEDULE DIAGNOSTICS ENDPOINTS +# ============================================================================ + + +class SchedulesListResponse(BaseModel): + """Response model for list of schedules""" + + schedules: List[ScheduleDetail] + total: int + + +class OrphanedSchedulesListResponse(BaseModel): + """Response model for list of orphaned schedules""" + + schedules: List[OrphanedScheduleDetail] + total: int + + +class ScheduleCleanupRequest(BaseModel): + """Request model for cleaning up schedules""" + + schedule_ids: List[str] + + +class ScheduleCleanupResponse(BaseModel): + """Response model for schedule cleanup operations""" + + success: bool + deleted_count: int = 0 + message: str + + +@router.get( + "/diagnostics/schedules", + response_model=ScheduleHealthMetrics, + summary="Get Schedule Diagnostics", +) +async def get_schedule_diagnostics_endpoint(): + """ + Get comprehensive diagnostic information about schedule health. + + Returns schedule metrics including: + - Total schedules (user vs system) + - Orphaned schedules by category + - Upcoming executions + """ + logger.info("Getting schedule diagnostics") + + diagnostics = await get_schedule_health_metrics() + + logger.info( + f"Schedule diagnostics: total={diagnostics.total_schedules}, " + f"user={diagnostics.user_schedules}, " + f"orphaned={diagnostics.total_orphaned}" + ) + + return diagnostics + + +@router.get( + "/diagnostics/schedules/all", + response_model=SchedulesListResponse, + summary="List All User Schedules", +) +async def list_all_schedules( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of all user schedules (excludes system monitoring jobs). + + Args: + limit: Maximum number of schedules to return (default 100) + offset: Number of schedules to skip (default 0) + + Returns: + List of schedules with details + """ + logger.info(f"Listing all schedules (limit={limit}, offset={offset})") + + schedules = await get_all_schedules_details(limit=limit, offset=offset) + + # Get total count + diagnostics = await get_schedule_health_metrics() + total = diagnostics.user_schedules + + return SchedulesListResponse(schedules=schedules, total=total) + + +@router.get( + "/diagnostics/schedules/orphaned", + response_model=OrphanedSchedulesListResponse, + summary="List Orphaned Schedules", +) +async def list_orphaned_schedules(): + """ + Get detailed list of orphaned schedules with orphan reasons. + + Returns: + List of orphaned schedules categorized by orphan type + """ + logger.info("Listing orphaned schedules") + + schedules = await get_orphaned_schedules_details() + + return OrphanedSchedulesListResponse(schedules=schedules, total=len(schedules)) + + +@router.post( + "/diagnostics/schedules/cleanup-orphaned", + response_model=ScheduleCleanupResponse, + summary="Cleanup Orphaned Schedules", +) +async def cleanup_orphaned_schedules( + request: ScheduleCleanupRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup orphaned schedules by deleting from scheduler (admin only). + + Args: + request: Contains list of schedule_ids to delete + + Returns: + Number of schedules deleted and success message + """ + logger.info( + f"Admin {user.user_id} cleaning up {len(request.schedule_ids)} orphaned schedules" + ) + + deleted_count = await cleanup_orphaned_schedules_bulk( + request.schedule_ids, user.user_id + ) + + return ScheduleCleanupResponse( + success=deleted_count > 0, + deleted_count=deleted_count, + message=f"Deleted {deleted_count} of {len(request.schedule_ids)} orphaned schedules", + ) + + +@router.post( + "/diagnostics/executions/stop-all-long-running", + response_model=StopExecutionResponse, + summary="Stop ALL Long-Running Executions", +) +async def stop_all_long_running_executions_endpoint( + user: AuthUser = Security(requires_admin_user), +): + """ + Stop ALL long-running executions (RUNNING >24h) by sending cancel signals (admin only). + Operates on entire dataset, not limited to pagination. + + Returns: + Number of executions stopped and success message + """ + logger.info(f"Admin {user.user_id} stopping ALL long-running executions") + + stopped_count = await stop_all_long_running_executions(user.user_id) + + return StopExecutionResponse( + success=stopped_count > 0, + stopped_count=stopped_count, + message=f"Stopped {stopped_count} long-running executions", + ) + + +@router.post( + "/diagnostics/executions/cleanup-all-orphaned", + response_model=StopExecutionResponse, + summary="Cleanup ALL Orphaned Executions", +) +async def cleanup_all_orphaned_executions( + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup ALL orphaned executions (>24h old) by directly updating DB status. + Operates on all executions, not just paginated results. + + Returns: + Number of executions cleaned up and success message + """ + logger.info(f"Admin {user.user_id} cleaning up ALL orphaned executions") + + # Fetch all orphaned execution IDs + execution_ids = await get_all_orphaned_execution_ids() + + if not execution_ids: + return StopExecutionResponse( + success=True, + stopped_count=0, + message="No orphaned executions to cleanup", + ) + + cleaned_count = await cleanup_orphaned_executions_bulk(execution_ids, user.user_id) + + return StopExecutionResponse( + success=cleaned_count > 0, + stopped_count=cleaned_count, + message=f"Cleaned up {cleaned_count} orphaned executions", + ) + + +@router.post( + "/diagnostics/executions/cleanup-all-stuck-queued", + response_model=StopExecutionResponse, + summary="Cleanup ALL Stuck Queued Executions", +) +async def cleanup_all_stuck_queued_executions_endpoint( + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup ALL stuck queued executions (QUEUED >1h) by updating DB status (admin only). + Operates on entire dataset, not limited to pagination. + + Returns: + Number of executions cleaned up and success message + """ + logger.info(f"Admin {user.user_id} cleaning up ALL stuck queued executions") + + cleaned_count = await cleanup_all_stuck_queued_executions(user.user_id) + + return StopExecutionResponse( + success=cleaned_count > 0, + stopped_count=cleaned_count, + message=f"Cleaned up {cleaned_count} stuck queued executions", + ) + + +@router.post( + "/diagnostics/executions/requeue-all-stuck", + response_model=RequeueExecutionResponse, + summary="Requeue ALL Stuck Queued Executions", +) +async def requeue_all_stuck_executions( + user: AuthUser = Security(requires_admin_user), +): + """ + Requeue ALL stuck queued executions (QUEUED >1h) by publishing to RabbitMQ. + Operates on all executions, not just paginated results. + + Uses add_graph_execution with existing graph_exec_id to requeue. + + ⚠️ WARNING: This will re-execute ALL stuck executions and may cost significant credits. + + Returns: + Number of executions requeued and success message + """ + logger.info(f"Admin {user.user_id} requeueing ALL stuck queued executions") + + # Fetch all stuck queued execution IDs + execution_ids = await get_all_stuck_queued_execution_ids() + + if not execution_ids: + return RequeueExecutionResponse( + success=True, + requeued_count=0, + message="No stuck queued executions to requeue", + ) + + # Get stuck executions by ID list (must be QUEUED) + executions = await get_graph_executions( + execution_ids=execution_ids, + statuses=[AgentExecutionStatus.QUEUED], + ) + + # Requeue all in parallel using add_graph_execution + async def requeue_one(exec) -> bool: + try: + await add_graph_execution( + graph_id=exec.graph_id, + user_id=exec.user_id, + graph_version=exec.graph_version, + graph_exec_id=exec.id, # Requeue existing + ) + return True + except Exception as e: + logger.error(f"Failed to requeue {exec.id}: {e}") + return False + + results = await asyncio.gather( + *[requeue_one(exec) for exec in executions], return_exceptions=False + ) + + requeued_count = sum(1 for success in results if success) + + return RequeueExecutionResponse( + success=requeued_count > 0, + requeued_count=requeued_count, + message=f"Requeued {requeued_count} stuck executions", + ) diff --git a/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes_test.py b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes_test.py new file mode 100644 index 0000000000..a3783312b0 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes_test.py @@ -0,0 +1,889 @@ +from datetime import datetime, timezone +from unittest.mock import AsyncMock + +import fastapi +import fastapi.testclient +import pytest +import pytest_mock +from autogpt_libs.auth.jwt_utils import get_jwt_payload +from prisma.enums import AgentExecutionStatus + +import backend.api.features.admin.diagnostics_admin_routes as diagnostics_admin_routes +from backend.data.diagnostics import ( + AgentDiagnosticsSummary, + ExecutionDiagnosticsSummary, + FailedExecutionDetail, + OrphanedScheduleDetail, + RunningExecutionDetail, + ScheduleDetail, + ScheduleHealthMetrics, +) +from backend.data.execution import GraphExecutionMeta + +app = fastapi.FastAPI() +app.include_router(diagnostics_admin_routes.router) + +client = fastapi.testclient.TestClient(app) + + +@pytest.fixture(autouse=True) +def setup_app_admin_auth(mock_jwt_admin): + """Setup admin auth overrides for all tests in this module""" + app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"] + yield + app.dependency_overrides.clear() + + +def test_get_execution_diagnostics_success( + mocker: pytest_mock.MockFixture, +): + """Test fetching execution diagnostics with invalid state detection""" + mock_diagnostics = ExecutionDiagnosticsSummary( + running_count=10, + queued_db_count=5, + rabbitmq_queue_depth=3, + cancel_queue_depth=0, + orphaned_running=2, + orphaned_queued=1, + failed_count_1h=5, + failed_count_24h=20, + failure_rate_24h=0.83, + stuck_running_24h=1, + stuck_running_1h=3, + oldest_running_hours=26.5, + stuck_queued_1h=2, + queued_never_started=1, + invalid_queued_with_start=1, # New invalid state + invalid_running_without_start=1, # New invalid state + completed_1h=50, + completed_24h=1200, + throughput_per_hour=50.0, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=mock_diagnostics, + ) + + response = client.get("/admin/diagnostics/executions") + + assert response.status_code == 200 + data = response.json() + + # Verify new invalid state fields are included + assert data["invalid_queued_with_start"] == 1 + assert data["invalid_running_without_start"] == 1 + # Verify all expected fields present + assert "running_executions" in data + assert "orphaned_running" in data + assert "failed_count_24h" in data + + +def test_list_invalid_executions( + mocker: pytest_mock.MockFixture, +): + """Test listing executions in invalid states (read-only endpoint)""" + mock_invalid_executions = [ + RunningExecutionDetail( + execution_id="exec-invalid-1", + graph_id="graph-123", + graph_name="Test Graph", + graph_version=1, + user_id="user-123", + user_email="test@example.com", + status="QUEUED", + created_at=datetime.now(timezone.utc), + started_at=datetime.now( + timezone.utc + ), # QUEUED but has startedAt - INVALID! + queue_status=None, + ), + RunningExecutionDetail( + execution_id="exec-invalid-2", + graph_id="graph-456", + graph_name="Another Graph", + graph_version=2, + user_id="user-456", + user_email="user@example.com", + status="RUNNING", + created_at=datetime.now(timezone.utc), + started_at=None, # RUNNING but no startedAt - INVALID! + queue_status=None, + ), + ] + + mock_diagnostics = ExecutionDiagnosticsSummary( + running_count=10, + queued_db_count=5, + rabbitmq_queue_depth=3, + cancel_queue_depth=0, + orphaned_running=0, + orphaned_queued=0, + failed_count_1h=0, + failed_count_24h=0, + failure_rate_24h=0.0, + stuck_running_24h=0, + stuck_running_1h=0, + oldest_running_hours=None, + stuck_queued_1h=0, + queued_never_started=0, + invalid_queued_with_start=1, + invalid_running_without_start=1, + completed_1h=0, + completed_24h=0, + throughput_per_hour=0.0, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_invalid_executions_details", + return_value=mock_invalid_executions, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=mock_diagnostics, + ) + + response = client.get("/admin/diagnostics/executions/invalid?limit=100&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 2 # Sum of both invalid state types + assert len(data["executions"]) == 2 + # Verify both types of invalid states are returned + assert data["executions"][0]["execution_id"] in [ + "exec-invalid-1", + "exec-invalid-2", + ] + assert data["executions"][1]["execution_id"] in [ + "exec-invalid-1", + "exec-invalid-2", + ] + + +def test_requeue_single_execution_with_add_graph_execution( + mocker: pytest_mock.MockFixture, + admin_user_id: str, +): + """Test requeueing uses add_graph_execution in requeue mode""" + mock_exec_meta = GraphExecutionMeta( + id="exec-stuck-123", + user_id="user-123", + graph_id="graph-456", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.QUEUED, + started_at=datetime.now(timezone.utc), + ended_at=datetime.now(timezone.utc), + stats=None, + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[mock_exec_meta], + ) + + mock_add_graph_execution = mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post( + "/admin/diagnostics/executions/requeue", + json={"execution_id": "exec-stuck-123"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["requeued_count"] == 1 + + # Verify it used add_graph_execution in requeue mode + mock_add_graph_execution.assert_called_once() + call_kwargs = mock_add_graph_execution.call_args.kwargs + assert call_kwargs["graph_exec_id"] == "exec-stuck-123" # Requeue mode! + assert call_kwargs["graph_id"] == "graph-456" + assert call_kwargs["user_id"] == "user-123" + + +def test_stop_single_execution_with_stop_graph_execution( + mocker: pytest_mock.MockFixture, + admin_user_id: str, +): + """Test stopping uses robust stop_graph_execution""" + mock_exec_meta = GraphExecutionMeta( + id="exec-running-123", + user_id="user-789", + graph_id="graph-999", + graph_version=2, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.RUNNING, + started_at=datetime.now(timezone.utc), + ended_at=datetime.now(timezone.utc), + stats=None, + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[mock_exec_meta], + ) + + mock_stop_graph_execution = mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post( + "/admin/diagnostics/executions/stop", + json={"execution_id": "exec-running-123"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 1 + + # Verify it used stop_graph_execution with cascade + mock_stop_graph_execution.assert_called_once() + call_kwargs = mock_stop_graph_execution.call_args.kwargs + assert call_kwargs["graph_exec_id"] == "exec-running-123" + assert call_kwargs["user_id"] == "user-789" + assert call_kwargs["cascade"] is True # Stops children too! + assert call_kwargs["wait_timeout"] == 15.0 + + +def test_requeue_not_queued_execution_fails( + mocker: pytest_mock.MockFixture, +): + """Test that requeue fails if execution is not in QUEUED status""" + # Mock an execution that's RUNNING (not QUEUED) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], # No QUEUED executions found + ) + + response = client.post( + "/admin/diagnostics/executions/requeue", + json={"execution_id": "exec-running-123"}, + ) + + assert response.status_code == 404 + assert "not found or not in QUEUED status" in response.json()["detail"] + + +def test_list_invalid_executions_no_bulk_actions( + mocker: pytest_mock.MockFixture, +): + """Verify invalid executions endpoint is read-only (no bulk actions)""" + # This is a documentation test - the endpoint exists but should not + # have corresponding cleanup/stop/requeue endpoints + + # These endpoints should NOT exist for invalid states: + invalid_bulk_endpoints = [ + "/admin/diagnostics/executions/cleanup-invalid", + "/admin/diagnostics/executions/stop-invalid", + "/admin/diagnostics/executions/requeue-invalid", + ] + + for endpoint in invalid_bulk_endpoints: + response = client.post(endpoint, json={"execution_ids": ["test"]}) + assert response.status_code == 404, f"{endpoint} should not exist (read-only)" + + +def test_execution_ids_filter_efficiency( + mocker: pytest_mock.MockFixture, +): + """Test that bulk operations use efficient execution_ids filter""" + mock_exec_metas = [ + GraphExecutionMeta( + id=f"exec-{i}", + user_id=f"user-{i}", + graph_id="graph-123", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.QUEUED, + started_at=datetime.now(timezone.utc), + ended_at=datetime.now(timezone.utc), + stats=None, + ) + for i in range(3) + ] + + mock_get_graph_executions = mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=mock_exec_metas, + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post( + "/admin/diagnostics/executions/requeue-bulk", + json={"execution_ids": ["exec-0", "exec-1", "exec-2"]}, + ) + + assert response.status_code == 200 + + # Verify it used execution_ids filter (not fetching all queued) + mock_get_graph_executions.assert_called_once() + call_kwargs = mock_get_graph_executions.call_args.kwargs + assert "execution_ids" in call_kwargs + assert call_kwargs["execution_ids"] == ["exec-0", "exec-1", "exec-2"] + assert call_kwargs["statuses"] == [AgentExecutionStatus.QUEUED] + + +# --------------------------------------------------------------------------- +# Helper: reusable mock diagnostics summary +# --------------------------------------------------------------------------- + + +def _make_mock_diagnostics(**overrides) -> ExecutionDiagnosticsSummary: + defaults = dict( + running_count=10, + queued_db_count=5, + rabbitmq_queue_depth=3, + cancel_queue_depth=0, + orphaned_running=2, + orphaned_queued=1, + failed_count_1h=5, + failed_count_24h=20, + failure_rate_24h=0.83, + stuck_running_24h=3, + stuck_running_1h=5, + oldest_running_hours=26.5, + stuck_queued_1h=2, + queued_never_started=1, + invalid_queued_with_start=1, + invalid_running_without_start=1, + completed_1h=50, + completed_24h=1200, + throughput_per_hour=50.0, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + defaults.update(overrides) + return ExecutionDiagnosticsSummary(**defaults) + + +_SENTINEL = object() + + +def _make_mock_execution( + exec_id: str = "exec-1", + status: str = "RUNNING", + started_at: datetime | None | object = _SENTINEL, +) -> RunningExecutionDetail: + return RunningExecutionDetail( + execution_id=exec_id, + graph_id="graph-123", + graph_name="Test Graph", + graph_version=1, + user_id="user-123", + user_email="test@example.com", + status=status, + created_at=datetime.now(timezone.utc), + started_at=( + datetime.now(timezone.utc) if started_at is _SENTINEL else started_at + ), + queue_status=None, + ) + + +def _make_mock_failed_execution( + exec_id: str = "exec-fail-1", +) -> FailedExecutionDetail: + return FailedExecutionDetail( + execution_id=exec_id, + graph_id="graph-123", + graph_name="Test Graph", + graph_version=1, + user_id="user-123", + user_email="test@example.com", + status="FAILED", + created_at=datetime.now(timezone.utc), + started_at=datetime.now(timezone.utc), + failed_at=datetime.now(timezone.utc), + error_message="Something went wrong", + ) + + +def _make_mock_schedule_health(**overrides) -> ScheduleHealthMetrics: + defaults = dict( + total_schedules=15, + user_schedules=10, + system_schedules=5, + orphaned_deleted_graph=2, + orphaned_no_library_access=1, + orphaned_invalid_credentials=0, + orphaned_validation_failed=0, + total_orphaned=3, + schedules_next_hour=4, + schedules_next_24h=8, + total_runs_next_hour=12, + total_runs_next_24h=48, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + defaults.update(overrides) + return ScheduleHealthMetrics(**defaults) + + +# --------------------------------------------------------------------------- +# GET endpoints: execution list variants +# --------------------------------------------------------------------------- + + +def test_list_running_executions(mocker: pytest_mock.MockFixture): + mock_execs = [ + _make_mock_execution("exec-run-1"), + _make_mock_execution("exec-run-2"), + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_running_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get("/admin/diagnostics/executions/running?limit=50&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 15 # running_count(10) + queued_db_count(5) + assert len(data["executions"]) == 2 + assert data["executions"][0]["execution_id"] == "exec-run-1" + + +def test_list_orphaned_executions(mocker: pytest_mock.MockFixture): + mock_execs = [_make_mock_execution("exec-orphan-1", status="RUNNING")] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_orphaned_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get("/admin/diagnostics/executions/orphaned?limit=50&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 3 # orphaned_running(2) + orphaned_queued(1) + assert len(data["executions"]) == 1 + + +def test_list_failed_executions(mocker: pytest_mock.MockFixture): + mock_execs = [_make_mock_failed_execution("exec-fail-1")] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_count", + return_value=42, + ) + + response = client.get( + "/admin/diagnostics/executions/failed?limit=50&offset=0&hours=24" + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 42 + assert len(data["executions"]) == 1 + assert data["executions"][0]["error_message"] == "Something went wrong" + + +def test_list_long_running_executions(mocker: pytest_mock.MockFixture): + mock_execs = [_make_mock_execution("exec-long-1")] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_long_running_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get( + "/admin/diagnostics/executions/long-running?limit=50&offset=0" + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 3 # stuck_running_24h + assert len(data["executions"]) == 1 + + +def test_list_stuck_queued_executions(mocker: pytest_mock.MockFixture): + mock_execs = [ + _make_mock_execution("exec-stuck-1", status="QUEUED", started_at=None) + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_stuck_queued_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get( + "/admin/diagnostics/executions/stuck-queued?limit=50&offset=0" + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 2 # stuck_queued_1h + assert len(data["executions"]) == 1 + + +# --------------------------------------------------------------------------- +# GET endpoints: agent + schedule diagnostics +# --------------------------------------------------------------------------- + + +def test_get_agent_diagnostics(mocker: pytest_mock.MockFixture): + mock_diag = AgentDiagnosticsSummary( + agents_with_active_executions=7, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_agent_diagnostics", + return_value=mock_diag, + ) + + response = client.get("/admin/diagnostics/agents") + + assert response.status_code == 200 + data = response.json() + assert data["agents_with_active_executions"] == 7 + + +def test_get_schedule_diagnostics(mocker: pytest_mock.MockFixture): + mock_metrics = _make_mock_schedule_health() + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics", + return_value=mock_metrics, + ) + + response = client.get("/admin/diagnostics/schedules") + + assert response.status_code == 200 + data = response.json() + assert data["user_schedules"] == 10 + assert data["total_orphaned"] == 3 + assert data["total_runs_next_hour"] == 12 + + +def test_list_all_schedules(mocker: pytest_mock.MockFixture): + mock_schedules = [ + ScheduleDetail( + schedule_id="sched-1", + schedule_name="Daily Run", + graph_id="graph-1", + graph_name="My Agent", + graph_version=1, + user_id="user-1", + user_email="alice@example.com", + cron="0 9 * * *", + timezone="UTC", + next_run_time=datetime.now(timezone.utc).isoformat(), + ), + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_schedules_details", + return_value=mock_schedules, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics", + return_value=_make_mock_schedule_health(), + ) + + response = client.get("/admin/diagnostics/schedules/all?limit=50&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 10 + assert len(data["schedules"]) == 1 + assert data["schedules"][0]["schedule_name"] == "Daily Run" + + +def test_list_orphaned_schedules(mocker: pytest_mock.MockFixture): + mock_orphans = [ + OrphanedScheduleDetail( + schedule_id="sched-orphan-1", + schedule_name="Ghost Schedule", + graph_id="graph-deleted", + graph_version=1, + user_id="user-1", + orphan_reason="deleted_graph", + error_detail=None, + next_run_time=datetime.now(timezone.utc).isoformat(), + ), + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_orphaned_schedules_details", + return_value=mock_orphans, + ) + + response = client.get("/admin/diagnostics/schedules/orphaned") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 1 + assert data["schedules"][0]["orphan_reason"] == "deleted_graph" + + +# --------------------------------------------------------------------------- +# POST endpoints: bulk stop, cleanup, requeue +# --------------------------------------------------------------------------- + + +def test_stop_multiple_executions(mocker: pytest_mock.MockFixture): + mock_exec_metas = [ + GraphExecutionMeta( + id=f"exec-{i}", + user_id=f"user-{i}", + graph_id="graph-123", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.RUNNING, + started_at=datetime.now(timezone.utc), + ended_at=None, + stats=None, + ) + for i in range(2) + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=mock_exec_metas, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post( + "/admin/diagnostics/executions/stop-bulk", + json={"execution_ids": ["exec-0", "exec-1"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 2 + + +def test_stop_multiple_executions_none_found(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], + ) + + response = client.post( + "/admin/diagnostics/executions/stop-bulk", + json={"execution_ids": ["nonexistent"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert data["stopped_count"] == 0 + + +def test_cleanup_orphaned_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk", + return_value=3, + ) + + response = client.post( + "/admin/diagnostics/executions/cleanup-orphaned", + json={"execution_ids": ["exec-1", "exec-2", "exec-3"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 3 + + +def test_cleanup_orphaned_schedules(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_schedules_bulk", + return_value=2, + ) + + response = client.post( + "/admin/diagnostics/schedules/cleanup-orphaned", + json={"schedule_ids": ["sched-1", "sched-2"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["deleted_count"] == 2 + + +def test_stop_all_long_running_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.stop_all_long_running_executions", + return_value=5, + ) + + response = client.post("/admin/diagnostics/executions/stop-all-long-running") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 5 + + +def test_cleanup_all_orphaned_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids", + return_value=["exec-1", "exec-2"], + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk", + return_value=2, + ) + + response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 2 + + +def test_cleanup_all_orphaned_executions_none(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids", + return_value=[], + ) + + response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 0 + assert "No orphaned" in data["message"] + + +def test_cleanup_all_stuck_queued_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_all_stuck_queued_executions", + return_value=4, + ) + + response = client.post("/admin/diagnostics/executions/cleanup-all-stuck-queued") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 4 + + +def test_requeue_all_stuck_executions(mocker: pytest_mock.MockFixture): + mock_exec_metas = [ + GraphExecutionMeta( + id=f"exec-stuck-{i}", + user_id=f"user-{i}", + graph_id="graph-123", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.QUEUED, + started_at=None, + ended_at=None, + stats=None, + ) + for i in range(3) + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids", + return_value=["exec-stuck-0", "exec-stuck-1", "exec-stuck-2"], + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=mock_exec_metas, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post("/admin/diagnostics/executions/requeue-all-stuck") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["requeued_count"] == 3 + + +def test_requeue_all_stuck_executions_none(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids", + return_value=[], + ) + + response = client.post("/admin/diagnostics/executions/requeue-all-stuck") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["requeued_count"] == 0 + assert "No stuck" in data["message"] + + +def test_requeue_bulk_none_found(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], + ) + + response = client.post( + "/admin/diagnostics/executions/requeue-bulk", + json={"execution_ids": ["nonexistent"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert data["requeued_count"] == 0 + + +def test_stop_single_execution_not_found(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], + ) + + response = client.post( + "/admin/diagnostics/executions/stop", + json={"execution_id": "nonexistent"}, + ) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] diff --git a/autogpt_platform/backend/backend/api/features/admin/model.py b/autogpt_platform/backend/backend/api/features/admin/model.py index 82f51e8e7a..c96c6d6433 100644 --- a/autogpt_platform/backend/backend/api/features/admin/model.py +++ b/autogpt_platform/backend/backend/api/features/admin/model.py @@ -14,3 +14,70 @@ class UserHistoryResponse(BaseModel): class AddUserCreditsResponse(BaseModel): new_balance: int transaction_key: str + + +class ExecutionDiagnosticsResponse(BaseModel): + """Response model for execution diagnostics""" + + # Current execution state + running_executions: int + queued_executions_db: int + queued_executions_rabbitmq: int + cancel_queue_depth: int + + # Orphaned execution detection + orphaned_running: int + orphaned_queued: int + + # Failure metrics + failed_count_1h: int + failed_count_24h: int + failure_rate_24h: float + + # Long-running detection + stuck_running_24h: int + stuck_running_1h: int + oldest_running_hours: float | None + + # Stuck queued detection + stuck_queued_1h: int + queued_never_started: int + + # Invalid state detection (data corruption - no auto-actions) + invalid_queued_with_start: int + invalid_running_without_start: int + + # Throughput metrics + completed_1h: int + completed_24h: int + throughput_per_hour: float + + timestamp: str + + +class AgentDiagnosticsResponse(BaseModel): + """Response model for agent diagnostics""" + + agents_with_active_executions: int + timestamp: str + + +class ScheduleHealthMetrics(BaseModel): + """Response model for schedule diagnostics""" + + total_schedules: int + user_schedules: int + system_schedules: int + + # Orphan detection + orphaned_deleted_graph: int + orphaned_no_library_access: int + orphaned_invalid_credentials: int + orphaned_validation_failed: int + total_orphaned: int + + # Upcoming + schedules_next_hour: int + schedules_next_24h: int + + timestamp: str diff --git a/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes.py b/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes.py new file mode 100644 index 0000000000..048c4ae07e --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes.py @@ -0,0 +1,141 @@ +import logging +from datetime import datetime + +from autogpt_libs.auth import get_user_id, requires_admin_user +from fastapi import APIRouter, Query, Security +from pydantic import BaseModel + +from backend.data.platform_cost import ( + CostLogRow, + PlatformCostDashboard, + get_platform_cost_dashboard, + get_platform_cost_logs, + get_platform_cost_logs_for_export, +) +from backend.util.models import Pagination + +logger = logging.getLogger(__name__) + + +router = APIRouter( + prefix="/platform-costs", + tags=["platform-cost", "admin"], + dependencies=[Security(requires_admin_user)], +) + + +class PlatformCostLogsResponse(BaseModel): + logs: list[CostLogRow] + pagination: Pagination + + +@router.get( + "/dashboard", + response_model=PlatformCostDashboard, + summary="Get Platform Cost Dashboard", +) +async def get_cost_dashboard( + admin_user_id: str = Security(get_user_id), + start: datetime | None = Query(None), + end: datetime | None = Query(None), + provider: str | None = Query(None), + user_id: str | None = Query(None), + model: str | None = Query(None), + block_name: str | None = Query(None), + tracking_type: str | None = Query(None), + graph_exec_id: str | None = Query(None), +): + logger.info("Admin %s fetching platform cost dashboard", admin_user_id) + return await get_platform_cost_dashboard( + start=start, + end=end, + provider=provider, + user_id=user_id, + model=model, + block_name=block_name, + tracking_type=tracking_type, + graph_exec_id=graph_exec_id, + ) + + +@router.get( + "/logs", + response_model=PlatformCostLogsResponse, + summary="Get Platform Cost Logs", +) +async def get_cost_logs( + admin_user_id: str = Security(get_user_id), + start: datetime | None = Query(None), + end: datetime | None = Query(None), + provider: str | None = Query(None), + user_id: str | None = Query(None), + page: int = Query(1, ge=1), + page_size: int = Query(50, ge=1, le=200), + model: str | None = Query(None), + block_name: str | None = Query(None), + tracking_type: str | None = Query(None), + graph_exec_id: str | None = Query(None), +): + logger.info("Admin %s fetching platform cost logs", admin_user_id) + logs, total = await get_platform_cost_logs( + start=start, + end=end, + provider=provider, + user_id=user_id, + page=page, + page_size=page_size, + model=model, + block_name=block_name, + tracking_type=tracking_type, + graph_exec_id=graph_exec_id, + ) + total_pages = (total + page_size - 1) // page_size + return PlatformCostLogsResponse( + logs=logs, + pagination=Pagination( + total_items=total, + total_pages=total_pages, + current_page=page, + page_size=page_size, + ), + ) + + +class PlatformCostExportResponse(BaseModel): + logs: list[CostLogRow] + total_rows: int + truncated: bool + + +@router.get( + "/logs/export", + response_model=PlatformCostExportResponse, + summary="Export Platform Cost Logs", +) +async def export_cost_logs( + admin_user_id: str = Security(get_user_id), + start: datetime | None = Query(None), + end: datetime | None = Query(None), + provider: str | None = Query(None), + user_id: str | None = Query(None), + model: str | None = Query(None), + block_name: str | None = Query(None), + tracking_type: str | None = Query(None), + graph_exec_id: str | None = Query(None), +): + logger.info("Admin %s exporting platform cost logs", admin_user_id) + logs, truncated = await get_platform_cost_logs_for_export( + start=start, + end=end, + provider=provider, + user_id=user_id, + model=model, + block_name=block_name, + tracking_type=tracking_type, + graph_exec_id=graph_exec_id, + ) + return PlatformCostExportResponse( + logs=logs, + total_rows=len(logs), + truncated=truncated, + ) diff --git a/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes_test.py b/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes_test.py new file mode 100644 index 0000000000..8cfc0e47b5 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes_test.py @@ -0,0 +1,291 @@ +from datetime import datetime, timezone +from unittest.mock import AsyncMock + +import fastapi +import fastapi.testclient +import pytest +import pytest_mock +from autogpt_libs.auth.jwt_utils import get_jwt_payload + +from backend.data.platform_cost import CostLogRow, PlatformCostDashboard + +from .platform_cost_routes import router as platform_cost_router + +app = fastapi.FastAPI() +app.include_router(platform_cost_router) + +client = fastapi.testclient.TestClient(app) + + +@pytest.fixture(autouse=True) +def setup_app_admin_auth(mock_jwt_admin): + """Setup admin auth overrides for all tests in this module""" + app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"] + yield + app.dependency_overrides.clear() + + +def test_get_dashboard_success( + mocker: pytest_mock.MockerFixture, +) -> None: + real_dashboard = PlatformCostDashboard( + by_provider=[], + by_user=[], + total_cost_microdollars=0, + total_requests=0, + total_users=0, + ) + mocker.patch( + "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard", + AsyncMock(return_value=real_dashboard), + ) + + response = client.get("/platform-costs/dashboard") + assert response.status_code == 200 + data = response.json() + assert "by_provider" in data + assert "by_user" in data + assert data["total_cost_microdollars"] == 0 + + +def test_get_logs_success( + mocker: pytest_mock.MockerFixture, +) -> None: + mocker.patch( + "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs", + AsyncMock(return_value=([], 0)), + ) + + response = client.get("/platform-costs/logs") + assert response.status_code == 200 + data = response.json() + assert data["logs"] == [] + assert data["pagination"]["total_items"] == 0 + + +def test_get_dashboard_with_filters( + mocker: pytest_mock.MockerFixture, +) -> None: + real_dashboard = PlatformCostDashboard( + by_provider=[], + by_user=[], + total_cost_microdollars=0, + total_requests=0, + total_users=0, + ) + mock_dashboard = AsyncMock(return_value=real_dashboard) + mocker.patch( + "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard", + mock_dashboard, + ) + + response = client.get( + "/platform-costs/dashboard", + params={ + "start": "2026-01-01T00:00:00", + "end": "2026-04-01T00:00:00", + "provider": "openai", + "user_id": "test-user-123", + }, + ) + assert response.status_code == 200 + mock_dashboard.assert_called_once() + call_kwargs = mock_dashboard.call_args.kwargs + assert call_kwargs["provider"] == "openai" + assert call_kwargs["user_id"] == "test-user-123" + assert call_kwargs["start"] is not None + assert call_kwargs["end"] is not None + + +def test_get_logs_with_pagination( + mocker: pytest_mock.MockerFixture, +) -> None: + mocker.patch( + "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs", + AsyncMock(return_value=([], 0)), + ) + + response = client.get( + "/platform-costs/logs", + params={"page": 2, "page_size": 25, "provider": "anthropic"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["pagination"]["current_page"] == 2 + assert data["pagination"]["page_size"] == 25 + + +def test_get_dashboard_requires_admin() -> None: + import fastapi + from fastapi import HTTPException + + def reject_jwt(request: fastapi.Request): + raise HTTPException(status_code=401, detail="Not authenticated") + + app.dependency_overrides[get_jwt_payload] = reject_jwt + try: + response = client.get("/platform-costs/dashboard") + assert response.status_code == 401 + response = client.get("/platform-costs/logs") + assert response.status_code == 401 + finally: + app.dependency_overrides.clear() + + +def test_get_dashboard_rejects_non_admin(mock_jwt_user, mock_jwt_admin) -> None: + """Non-admin JWT must be rejected with 403 by requires_admin_user.""" + app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"] + try: + response = client.get("/platform-costs/dashboard") + assert response.status_code == 403 + response = client.get("/platform-costs/logs") + assert response.status_code == 403 + finally: + app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"] + + +def test_get_logs_invalid_page_size_too_large() -> None: + """page_size > 200 must be rejected with 422.""" + response = client.get("/platform-costs/logs", params={"page_size": 201}) + assert response.status_code == 422 + + +def test_get_logs_invalid_page_size_zero() -> None: + """page_size = 0 (below ge=1) must be rejected with 422.""" + response = client.get("/platform-costs/logs", params={"page_size": 0}) + assert response.status_code == 422 + + +def test_get_logs_invalid_page_negative() -> None: + """page < 1 must be rejected with 422.""" + response = client.get("/platform-costs/logs", params={"page": 0}) + assert response.status_code == 422 + + +def test_get_dashboard_invalid_date_format() -> None: + """Malformed start date must be rejected with 422.""" + response = client.get("/platform-costs/dashboard", params={"start": "not-a-date"}) + assert response.status_code == 422 + + +def test_get_dashboard_repeated_requests( + mocker: pytest_mock.MockerFixture, +) -> None: + """Repeated requests to the dashboard route both return 200.""" + real_dashboard = PlatformCostDashboard( + by_provider=[], + by_user=[], + total_cost_microdollars=42, + total_requests=1, + total_users=1, + ) + mocker.patch( + "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard", + AsyncMock(return_value=real_dashboard), + ) + + r1 = client.get("/platform-costs/dashboard") + r2 = client.get("/platform-costs/dashboard") + + assert r1.status_code == 200 + assert r2.status_code == 200 + assert r1.json()["total_cost_microdollars"] == 42 + assert r2.json()["total_cost_microdollars"] == 42 + + +def _make_cost_log_row() -> CostLogRow: + return CostLogRow( + id="log-1", + created_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + user_id="user-1", + email="u***@example.com", + graph_exec_id="graph-1", + node_exec_id="node-1", + block_name="LlmCallBlock", + provider="anthropic", + tracking_type="token", + cost_microdollars=500, + input_tokens=100, + output_tokens=50, + cache_read_tokens=10, + cache_creation_tokens=5, + duration=1.5, + model="claude-3-5-sonnet-20241022", + ) + + +def test_export_logs_success( + mocker: pytest_mock.MockerFixture, +) -> None: + row = _make_cost_log_row() + mocker.patch( + "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export", + AsyncMock(return_value=([row], False)), + ) + + response = client.get("/platform-costs/logs/export") + assert response.status_code == 200 + data = response.json() + assert data["total_rows"] == 1 + assert data["truncated"] is False + assert len(data["logs"]) == 1 + assert data["logs"][0]["cache_read_tokens"] == 10 + assert data["logs"][0]["cache_creation_tokens"] == 5 + + +def test_export_logs_truncated( + mocker: pytest_mock.MockerFixture, +) -> None: + rows = [_make_cost_log_row() for _ in range(3)] + mocker.patch( + "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export", + AsyncMock(return_value=(rows, True)), + ) + + response = client.get("/platform-costs/logs/export") + assert response.status_code == 200 + data = response.json() + assert data["total_rows"] == 3 + assert data["truncated"] is True + + +def test_export_logs_with_filters( + mocker: pytest_mock.MockerFixture, +) -> None: + mock_export = AsyncMock(return_value=([], False)) + mocker.patch( + "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export", + mock_export, + ) + + response = client.get( + "/platform-costs/logs/export", + params={ + "provider": "anthropic", + "model": "claude-3-5-sonnet-20241022", + "block_name": "LlmCallBlock", + "tracking_type": "token", + }, + ) + assert response.status_code == 200 + mock_export.assert_called_once() + call_kwargs = mock_export.call_args.kwargs + assert call_kwargs["provider"] == "anthropic" + assert call_kwargs["model"] == "claude-3-5-sonnet-20241022" + assert call_kwargs["block_name"] == "LlmCallBlock" + assert call_kwargs["tracking_type"] == "token" + + +def test_export_logs_requires_admin() -> None: + import fastapi + from fastapi import HTTPException + + def reject_jwt(request: fastapi.Request): + raise HTTPException(status_code=401, detail="Not authenticated") + + app.dependency_overrides[get_jwt_payload] = reject_jwt + try: + response = client.get("/platform-costs/logs/export") + assert response.status_code == 401 + finally: + app.dependency_overrides.clear() diff --git a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py index 49caada729..3b9c762f21 100644 --- a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py +++ b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py @@ -9,11 +9,14 @@ from pydantic import BaseModel from backend.copilot.config import ChatConfig from backend.copilot.rate_limit import ( + SubscriptionTier, get_global_rate_limits, get_usage_status, + get_user_tier, reset_user_usage, + set_user_tier, ) -from backend.data.user import get_user_by_email, get_user_email_by_id +from backend.data.user import get_user_by_email, get_user_email_by_id, search_users logger = logging.getLogger(__name__) @@ -29,10 +32,21 @@ router = APIRouter( class UserRateLimitResponse(BaseModel): user_id: str user_email: Optional[str] = None - daily_token_limit: int - weekly_token_limit: int - daily_tokens_used: int - weekly_tokens_used: int + daily_cost_limit_microdollars: int + weekly_cost_limit_microdollars: int + daily_cost_used_microdollars: int + weekly_cost_used_microdollars: int + tier: SubscriptionTier + + +class UserTierResponse(BaseModel): + user_id: str + tier: SubscriptionTier + + +class SetUserTierRequest(BaseModel): + user_id: str + tier: SubscriptionTier async def _resolve_user_id( @@ -86,18 +100,21 @@ async def get_user_rate_limit( logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id) - daily_limit, weekly_limit = await get_global_rate_limits( - resolved_id, config.daily_token_limit, config.weekly_token_limit + daily_limit, weekly_limit, tier = await get_global_rate_limits( + resolved_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) - usage = await get_usage_status(resolved_id, daily_limit, weekly_limit) + usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier) return UserRateLimitResponse( user_id=resolved_id, user_email=resolved_email, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, - daily_tokens_used=usage.daily.used, - weekly_tokens_used=usage.weekly.used, + daily_cost_limit_microdollars=daily_limit, + weekly_cost_limit_microdollars=weekly_limit, + daily_cost_used_microdollars=usage.daily.used, + weekly_cost_used_microdollars=usage.weekly.used, + tier=tier, ) @@ -125,10 +142,12 @@ async def reset_user_rate_limit( logger.exception("Failed to reset user usage") raise HTTPException(status_code=500, detail="Failed to reset usage") from e - daily_limit, weekly_limit = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + daily_limit, weekly_limit, tier = await get_global_rate_limits( + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) - usage = await get_usage_status(user_id, daily_limit, weekly_limit) + usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier) try: resolved_email = await get_user_email_by_id(user_id) @@ -139,8 +158,106 @@ async def reset_user_rate_limit( return UserRateLimitResponse( user_id=user_id, user_email=resolved_email, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, - daily_tokens_used=usage.daily.used, - weekly_tokens_used=usage.weekly.used, + daily_cost_limit_microdollars=daily_limit, + weekly_cost_limit_microdollars=weekly_limit, + daily_cost_used_microdollars=usage.daily.used, + weekly_cost_used_microdollars=usage.weekly.used, + tier=tier, ) + + +@router.get( + "/rate_limit/tier", + response_model=UserTierResponse, + summary="Get User Rate Limit Tier", +) +async def get_user_rate_limit_tier( + user_id: str, + admin_user_id: str = Security(get_user_id), +) -> UserTierResponse: + """Get a user's current rate-limit tier. Admin-only. + + Returns 404 if the user does not exist in the database. + """ + logger.info("Admin %s checking tier for user %s", admin_user_id, user_id) + + resolved_email = await get_user_email_by_id(user_id) + if resolved_email is None: + raise HTTPException(status_code=404, detail=f"User {user_id} not found") + + tier = await get_user_tier(user_id) + return UserTierResponse(user_id=user_id, tier=tier) + + +@router.post( + "/rate_limit/tier", + response_model=UserTierResponse, + summary="Set User Rate Limit Tier", +) +async def set_user_rate_limit_tier( + request: SetUserTierRequest, + admin_user_id: str = Security(get_user_id), +) -> UserTierResponse: + """Set a user's rate-limit tier. Admin-only. + + Returns 404 if the user does not exist in the database. + """ + try: + resolved_email = await get_user_email_by_id(request.user_id) + except Exception: + logger.warning( + "Failed to resolve email for user %s", + request.user_id, + exc_info=True, + ) + resolved_email = None + + if resolved_email is None: + raise HTTPException(status_code=404, detail=f"User {request.user_id} not found") + + old_tier = await get_user_tier(request.user_id) + logger.info( + "Admin %s changing tier for user %s (%s): %s -> %s", + admin_user_id, + request.user_id, + resolved_email, + old_tier.value, + request.tier.value, + ) + try: + await set_user_tier(request.user_id, request.tier) + except Exception as e: + logger.exception("Failed to set user tier") + raise HTTPException(status_code=500, detail="Failed to set tier") from e + + return UserTierResponse(user_id=request.user_id, tier=request.tier) + + +class UserSearchResult(BaseModel): + user_id: str + user_email: Optional[str] = None + + +@router.get( + "/rate_limit/search_users", + response_model=list[UserSearchResult], + summary="Search Users by Name or Email", +) +async def admin_search_users( + query: str, + limit: int = 20, + admin_user_id: str = Security(get_user_id), +) -> list[UserSearchResult]: + """Search users by partial email or name. Admin-only. + + Queries the User table directly — returns results even for users + without credit transaction history. + """ + if len(query.strip()) < 3: + raise HTTPException( + status_code=400, + detail="Search query must be at least 3 characters.", + ) + logger.info("Admin %s searching users with query=%r", admin_user_id, query) + results = await search_users(query, limit=max(1, min(limit, 50))) + return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results] diff --git a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py index 6560715b63..c6c920829d 100644 --- a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py +++ b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py @@ -9,7 +9,7 @@ import pytest_mock from autogpt_libs.auth.jwt_utils import get_jwt_payload from pytest_snapshot.plugin import Snapshot -from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow +from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow from .rate_limit_admin_routes import router as rate_limit_admin_router @@ -57,7 +57,7 @@ def _patch_rate_limit_deps( mocker.patch( f"{_MOCK_MODULE}.get_global_rate_limits", new_callable=AsyncMock, - return_value=(2_500_000, 12_500_000), + return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE), ) mocker.patch( f"{_MOCK_MODULE}.get_usage_status", @@ -85,10 +85,11 @@ def test_get_rate_limit( data = response.json() assert data["user_id"] == target_user_id assert data["user_email"] == _TARGET_EMAIL - assert data["daily_token_limit"] == 2_500_000 - assert data["weekly_token_limit"] == 12_500_000 - assert data["daily_tokens_used"] == 500_000 - assert data["weekly_tokens_used"] == 3_000_000 + assert data["daily_cost_limit_microdollars"] == 2_500_000 + assert data["weekly_cost_limit_microdollars"] == 12_500_000 + assert data["daily_cost_used_microdollars"] == 500_000 + assert data["weekly_cost_used_microdollars"] == 3_000_000 + assert data["tier"] == "FREE" configured_snapshot.assert_match( json.dumps(data, indent=2, sort_keys=True) + "\n", @@ -116,7 +117,7 @@ def test_get_rate_limit_by_email( data = response.json() assert data["user_id"] == target_user_id assert data["user_email"] == _TARGET_EMAIL - assert data["daily_token_limit"] == 2_500_000 + assert data["daily_cost_limit_microdollars"] == 2_500_000 def test_get_rate_limit_by_email_not_found( @@ -159,9 +160,10 @@ def test_reset_user_usage_daily_only( assert response.status_code == 200 data = response.json() - assert data["daily_tokens_used"] == 0 + assert data["daily_cost_used_microdollars"] == 0 # Weekly is untouched - assert data["weekly_tokens_used"] == 3_000_000 + assert data["weekly_cost_used_microdollars"] == 3_000_000 + assert data["tier"] == "FREE" mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False) @@ -190,8 +192,9 @@ def test_reset_user_usage_daily_and_weekly( assert response.status_code == 200 data = response.json() - assert data["daily_tokens_used"] == 0 - assert data["weekly_tokens_used"] == 0 + assert data["daily_cost_used_microdollars"] == 0 + assert data["weekly_cost_used_microdollars"] == 0 + assert data["tier"] == "FREE" mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True) @@ -228,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure( mocker.patch( f"{_MOCK_MODULE}.get_global_rate_limits", new_callable=AsyncMock, - return_value=(2_500_000, 12_500_000), + return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE), ) mocker.patch( f"{_MOCK_MODULE}.get_usage_status", @@ -261,3 +264,303 @@ def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None: json={"user_id": "test"}, ) assert response.status_code == 403 + + +# --------------------------------------------------------------------------- +# Tier management endpoints +# --------------------------------------------------------------------------- + + +def test_get_user_tier( + mocker: pytest_mock.MockerFixture, + target_user_id: str, +) -> None: + """Test getting a user's rate-limit tier.""" + mocker.patch( + f"{_MOCK_MODULE}.get_user_email_by_id", + new_callable=AsyncMock, + return_value=_TARGET_EMAIL, + ) + mocker.patch( + f"{_MOCK_MODULE}.get_user_tier", + new_callable=AsyncMock, + return_value=SubscriptionTier.PRO, + ) + + response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id}) + + assert response.status_code == 200 + data = response.json() + assert data["user_id"] == target_user_id + assert data["tier"] == "PRO" + + +def test_get_user_tier_user_not_found( + mocker: pytest_mock.MockerFixture, + target_user_id: str, +) -> None: + """Test that getting tier for a non-existent user returns 404.""" + mocker.patch( + f"{_MOCK_MODULE}.get_user_email_by_id", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id}) + + assert response.status_code == 404 + + +def test_set_user_tier( + mocker: pytest_mock.MockerFixture, + target_user_id: str, +) -> None: + """Test setting a user's rate-limit tier (upgrade).""" + mocker.patch( + f"{_MOCK_MODULE}.get_user_email_by_id", + new_callable=AsyncMock, + return_value=_TARGET_EMAIL, + ) + mocker.patch( + f"{_MOCK_MODULE}.get_user_tier", + new_callable=AsyncMock, + return_value=SubscriptionTier.FREE, + ) + mock_set = mocker.patch( + f"{_MOCK_MODULE}.set_user_tier", + new_callable=AsyncMock, + ) + + response = client.post( + "/admin/rate_limit/tier", + json={"user_id": target_user_id, "tier": "ENTERPRISE"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["user_id"] == target_user_id + assert data["tier"] == "ENTERPRISE" + mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE) + + +def test_set_user_tier_downgrade( + mocker: pytest_mock.MockerFixture, + target_user_id: str, +) -> None: + """Test downgrading a user's tier from PRO to FREE.""" + mocker.patch( + f"{_MOCK_MODULE}.get_user_email_by_id", + new_callable=AsyncMock, + return_value=_TARGET_EMAIL, + ) + mocker.patch( + f"{_MOCK_MODULE}.get_user_tier", + new_callable=AsyncMock, + return_value=SubscriptionTier.PRO, + ) + mock_set = mocker.patch( + f"{_MOCK_MODULE}.set_user_tier", + new_callable=AsyncMock, + ) + + response = client.post( + "/admin/rate_limit/tier", + json={"user_id": target_user_id, "tier": "FREE"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["user_id"] == target_user_id + assert data["tier"] == "FREE" + mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE) + + +def test_set_user_tier_invalid_tier( + target_user_id: str, +) -> None: + """Test that setting an invalid tier returns 422.""" + response = client.post( + "/admin/rate_limit/tier", + json={"user_id": target_user_id, "tier": "invalid"}, + ) + + assert response.status_code == 422 + + +def test_set_user_tier_invalid_tier_uppercase( + target_user_id: str, +) -> None: + """Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422. + + Regression: ensures Pydantic enum validation rejects values that are not + members of SubscriptionTier, even when they look like valid enum names. + """ + response = client.post( + "/admin/rate_limit/tier", + json={"user_id": target_user_id, "tier": "INVALID"}, + ) + + assert response.status_code == 422 + body = response.json() + assert "detail" in body + + +def test_set_user_tier_email_lookup_failure_returns_404( + mocker: pytest_mock.MockerFixture, + target_user_id: str, +) -> None: + """Test that email lookup failure returns 404 (user unverifiable).""" + mocker.patch( + f"{_MOCK_MODULE}.get_user_email_by_id", + new_callable=AsyncMock, + side_effect=Exception("DB connection failed"), + ) + + response = client.post( + "/admin/rate_limit/tier", + json={"user_id": target_user_id, "tier": "PRO"}, + ) + + assert response.status_code == 404 + + +def test_set_user_tier_user_not_found( + mocker: pytest_mock.MockerFixture, + target_user_id: str, +) -> None: + """Test that setting tier for a non-existent user returns 404.""" + mocker.patch( + f"{_MOCK_MODULE}.get_user_email_by_id", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.post( + "/admin/rate_limit/tier", + json={"user_id": target_user_id, "tier": "PRO"}, + ) + + assert response.status_code == 404 + + +def test_set_user_tier_db_failure( + mocker: pytest_mock.MockerFixture, + target_user_id: str, +) -> None: + """Test that DB failure on set tier returns 500.""" + mocker.patch( + f"{_MOCK_MODULE}.get_user_email_by_id", + new_callable=AsyncMock, + return_value=_TARGET_EMAIL, + ) + mocker.patch( + f"{_MOCK_MODULE}.get_user_tier", + new_callable=AsyncMock, + return_value=SubscriptionTier.FREE, + ) + mocker.patch( + f"{_MOCK_MODULE}.set_user_tier", + new_callable=AsyncMock, + side_effect=Exception("DB connection refused"), + ) + + response = client.post( + "/admin/rate_limit/tier", + json={"user_id": target_user_id, "tier": "PRO"}, + ) + + assert response.status_code == 500 + + +def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None: + """Test that tier admin endpoints require admin role.""" + app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"] + + response = client.get("/admin/rate_limit/tier", params={"user_id": "test"}) + assert response.status_code == 403 + + response = client.post( + "/admin/rate_limit/tier", + json={"user_id": "test", "tier": "PRO"}, + ) + assert response.status_code == 403 + + +# ─── search_users endpoint ────────────────────────────────────────── + + +def test_search_users_returns_matching_users( + mocker: pytest_mock.MockerFixture, + admin_user_id: str, +) -> None: + """Partial search should return all matching users from the User table.""" + mocker.patch( + _MOCK_MODULE + ".search_users", + new_callable=AsyncMock, + return_value=[ + ("user-1", "zamil.majdy@gmail.com"), + ("user-2", "zamil.majdy@agpt.co"), + ], + ) + + response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"}) + + assert response.status_code == 200 + results = response.json() + assert len(results) == 2 + assert results[0]["user_email"] == "zamil.majdy@gmail.com" + assert results[1]["user_email"] == "zamil.majdy@agpt.co" + + +def test_search_users_empty_results( + mocker: pytest_mock.MockerFixture, + admin_user_id: str, +) -> None: + """Search with no matches returns empty list.""" + mocker.patch( + _MOCK_MODULE + ".search_users", + new_callable=AsyncMock, + return_value=[], + ) + + response = client.get( + "/admin/rate_limit/search_users", params={"query": "nonexistent"} + ) + + assert response.status_code == 200 + assert response.json() == [] + + +def test_search_users_short_query_rejected( + admin_user_id: str, +) -> None: + """Query shorter than 3 characters should return 400.""" + response = client.get("/admin/rate_limit/search_users", params={"query": "ab"}) + assert response.status_code == 400 + + +def test_search_users_negative_limit_clamped( + mocker: pytest_mock.MockerFixture, + admin_user_id: str, +) -> None: + """Negative limit should be clamped to 1, not passed through.""" + mock_search = mocker.patch( + _MOCK_MODULE + ".search_users", + new_callable=AsyncMock, + return_value=[], + ) + + response = client.get( + "/admin/rate_limit/search_users", params={"query": "test", "limit": -1} + ) + + assert response.status_code == 200 + mock_search.assert_awaited_once_with("test", limit=1) + + +def test_search_users_requires_admin_role(mock_jwt_user) -> None: + """Test that the search_users endpoint requires admin role.""" + app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"] + + response = client.get("/admin/rate_limit/search_users", params={"query": "test"}) + assert response.status_code == 403 diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index a4d61688f3..ca7e4355f6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -2,20 +2,20 @@ import asyncio import logging -import re from collections.abc import AsyncGenerator from typing import Annotated from uuid import uuid4 from autogpt_libs import auth from fastapi import APIRouter, HTTPException, Query, Response, Security -from fastapi.responses import StreamingResponse -from prisma.models import UserWorkspaceFile +from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, ConfigDict, Field, field_validator from backend.copilot import service as chat_service from backend.copilot import stream_registry -from backend.copilot.config import ChatConfig +from backend.copilot.builder_context import resolve_session_permissions +from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode +from backend.copilot.db import get_chat_messages_paginated from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn from backend.copilot.model import ( ChatMessage, @@ -25,11 +25,18 @@ from backend.copilot.model import ( create_chat_session, delete_chat_session, get_chat_session, + get_or_create_builder_session, get_user_sessions, update_session_title, ) +from backend.copilot.pending_message_helpers import ( + QueuePendingMessageResponse, + is_turn_in_flight, + queue_pending_for_http, +) +from backend.copilot.pending_messages import peek_pending_messages from backend.copilot.rate_limit import ( - CoPilotUsageStatus, + CoPilotUsagePublic, RateLimitExceeded, acquire_reset_lock, check_rate_limit, @@ -41,6 +48,7 @@ from backend.copilot.rate_limit import ( reset_daily_usage, ) from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat +from backend.copilot.service import strip_injected_context_for_display from backend.copilot.tools.e2b_sandbox import kill_sandbox from backend.copilot.tools.models import ( AgentDetailsResponse, @@ -59,6 +67,10 @@ from backend.copilot.tools.models import ( InputValidationErrorResponse, MCPToolOutputResponse, MCPToolsDiscoveredResponse, + MemoryForgetCandidatesResponse, + MemoryForgetConfirmResponse, + MemorySearchResponse, + MemoryStoreResponse, NeedLoginResponse, NoResultsResponse, SetupRequirementsResponse, @@ -69,7 +81,7 @@ from backend.copilot.tracking import track_user_message from backend.data.credit import UsageTransactionMetadata, get_user_credit_model from backend.data.redis_client import get_redis_async from backend.data.understanding import get_business_understanding -from backend.data.workspace import get_or_create_workspace +from backend.data.workspace import build_files_block, resolve_workspace_files from backend.util.exceptions import InsufficientBalanceError, NotFoundError from backend.util.settings import Settings @@ -79,10 +91,6 @@ logger = logging.getLogger(__name__) config = ChatConfig() -_UUID_RE = re.compile( - r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I -) - async def _validate_and_get_session( session_id: str, @@ -99,30 +107,91 @@ router = APIRouter( tags=["chat"], ) + +def _strip_injected_context(message: dict) -> dict: + """Hide server-injected context blocks from the API response. + + Returns a **shallow copy** of *message* with all server-injected XML + blocks removed from ``content`` (if applicable). The original dict is + never mutated, so callers can safely pass live session dicts without + risking side-effects. + + Handles all three injected block types — ````, + ````, and ```` — regardless of the order they + appear at the start of the message. Only ``user``-role messages with + string content are touched; assistant / multimodal blocks pass through + unchanged. + """ + if message.get("role") == "user" and isinstance(message.get("content"), str): + result = message.copy() + result["content"] = strip_injected_context_for_display(message["content"]) + return result + return message + + # ========== Request/Response Models ========== class StreamChatRequest(BaseModel): """Request model for streaming chat with optional context.""" - message: str + message: str = Field(max_length=64_000) is_user_message: bool = True context: dict[str, str] | None = None # {url: str, content: str} file_ids: list[str] | None = Field( default=None, max_length=20 ) # Workspace file IDs attached to this message + mode: CopilotMode | None = Field( + default=None, + description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. " + "If None, uses the server default (extended_thinking).", + ) + model: CopilotLlmModel | None = Field( + default=None, + description="Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. " + "If None, the server applies per-user LD targeting then falls back to config.", + ) + + +class PeekPendingMessagesResponse(BaseModel): + """Response for the pending-message peek (GET) endpoint. + + Returns a read-only view of the pending buffer — messages are NOT + consumed. The frontend uses this to restore the queued-message + indicator after a page refresh and to decide when to clear it once + a turn has ended. + """ + + messages: list[str] + count: int class CreateSessionRequest(BaseModel): - """Request model for creating a new chat session. + """Request model for creating (or get-or-creating) a chat session. + + Two modes, selected by the body: + + - Default: create a fresh session. ``dry_run`` is a **top-level** + field — do not nest it inside ``metadata``. + - Builder-bound: when ``builder_graph_id`` is set, the endpoint + switches to **get-or-create** keyed on + ``(user_id, builder_graph_id)``. The builder panel calls this on + mount so the chat persists across refreshes. Graph ownership is + validated inside :func:`get_or_create_builder_session`. Write-side + scope is enforced per-tool (``edit_agent`` / ``run_agent`` reject + any ``agent_id`` other than the bound graph) and a small blacklist + hides tools that conflict with the panel's scope + (``create_agent`` / ``customize_agent`` / ``get_agent_building_guide`` + — see :data:`BUILDER_BLOCKED_TOOLS`). Read-side lookups + (``find_block``, ``find_agent``, ``search_docs``, …) stay open. - ``dry_run`` is a **top-level** field — do not nest it inside ``metadata``. Extra/unknown fields are rejected (422) to prevent silent mis-use. """ model_config = ConfigDict(extra="forbid") dry_run: bool = False + builder_graph_id: str | None = Field(default=None, max_length=128) class CreateSessionResponse(BaseModel): @@ -150,6 +219,8 @@ class SessionDetailResponse(BaseModel): user_id: str | None messages: list[dict] active_stream: ActiveStreamInfo | None = None # Present if stream is still active + has_more_messages: bool = False + oldest_sequence: int | None = None total_prompt_tokens: int = 0 total_completion_tokens: int = 0 metadata: ChatSessionMetadata = ChatSessionMetadata() @@ -265,29 +336,43 @@ async def create_session( user_id: Annotated[str, Security(auth.get_user_id)], request: CreateSessionRequest | None = None, ) -> CreateSessionResponse: - """ - Create a new chat session. + """Create (or get-or-create) a chat session. - Initiates a new chat session for the authenticated user. + Two modes, selected by the request body: + + - Default: create a fresh session for the user. ``dry_run=True`` forces + run_block and run_agent calls to use dry-run simulation. + - Builder-bound: when ``builder_graph_id`` is set, get-or-create keyed + on ``(user_id, builder_graph_id)``. Returns the existing session for + that graph or creates one locked to it. Graph ownership is validated + inside :func:`get_or_create_builder_session`; raises 404 on + unauthorized access. Write-side scope is enforced per-tool + (``edit_agent`` / ``run_agent`` reject any ``agent_id`` other than + the bound graph) and a small blacklist hides tools that conflict + with the panel's scope (see :data:`BUILDER_BLOCKED_TOOLS`). Args: user_id: The authenticated user ID parsed from the JWT (required). - request: Optional request body. When provided, ``dry_run=True`` - forces run_block and run_agent calls to use dry-run simulation. + request: Optional request body with ``dry_run`` and/or + ``builder_graph_id``. Returns: - CreateSessionResponse: Details of the created session. - + CreateSessionResponse: Details of the resulting session. """ dry_run = request.dry_run if request else False + builder_graph_id = request.builder_graph_id if request else None logger.info( f"Creating session with user_id: " f"...{user_id[-8:] if len(user_id) > 8 else ''}" f"{', dry_run=True' if dry_run else ''}" + f"{f', builder_graph_id={builder_graph_id}' if builder_graph_id else ''}" ) - session = await create_chat_session(user_id, dry_run=dry_run) + if builder_graph_id: + session = await get_or_create_builder_session(user_id, builder_graph_id) + else: + session = await create_chat_session(user_id, dry_run=dry_run) return CreateSessionResponse( id=session.session_id, @@ -346,6 +431,31 @@ async def delete_session( return Response(status_code=204) +@router.delete( + "/sessions/{session_id}/stream", + dependencies=[Security(auth.requires_user)], + status_code=204, +) +async def disconnect_session_stream( + session_id: str, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> Response: + """Disconnect all active SSE listeners for a session. + + Called by the frontend when the user switches away from a chat so the + backend releases XREAD listeners immediately rather than waiting for + the 5-10 s timeout. + """ + session = await get_chat_session(session_id, user_id) + if not session: + raise HTTPException( + status_code=404, + detail=f"Session {session_id} not found or access denied", + ) + await stream_registry.disconnect_all_listeners(session_id) + return Response(status_code=204) + + @router.patch( "/sessions/{session_id}/title", summary="Update session title", @@ -389,60 +499,67 @@ async def update_session_title_route( async def get_session( session_id: str, user_id: Annotated[str, Security(auth.get_user_id)], + limit: int = Query(default=50, ge=1, le=200), + before_sequence: int | None = Query(default=None, ge=0), ) -> SessionDetailResponse: """ Retrieve the details of a specific chat session. - Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages. - If there's an active stream for this session, returns active_stream info for reconnection. - - Args: - session_id: The unique identifier for the desired chat session. - user_id: The optional authenticated user ID, or None for anonymous access. - - Returns: - SessionDetailResponse: Details for the requested session, including active_stream info if applicable. - + Supports cursor-based pagination via ``limit`` and ``before_sequence``. + When no pagination params are provided, returns the most recent messages. """ - session = await get_chat_session(session_id, user_id) - if not session: + page = await get_chat_messages_paginated( + session_id, limit, before_sequence, user_id=user_id + ) + if page is None: raise NotFoundError(f"Session {session_id} not found.") - messages = [message.model_dump() for message in session.messages] + messages = [ + _strip_injected_context(message.model_dump()) for message in page.messages + ] - # Check if there's an active stream for this session + # Only check active stream on initial load (not on "load more" requests) active_stream_info = None - active_session, last_message_id = await stream_registry.get_active_session( - session_id, user_id - ) - logger.info( - f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, " - f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}" - ) - if active_session: - # Keep the assistant message (including tool_calls) so the frontend can - # render the correct tool UI (e.g. CreateAgent with mini game). - # convertChatSessionToUiMessages handles isComplete=false by setting - # tool parts without output to state "input-available". - active_stream_info = ActiveStreamInfo( - turn_id=active_session.turn_id, - last_message_id=last_message_id, + if before_sequence is None: + active_session, last_message_id = await stream_registry.get_active_session( + session_id, user_id + ) + if active_session: + active_stream_info = ActiveStreamInfo( + turn_id=active_session.turn_id, + last_message_id=last_message_id, + ) + + # Skip session metadata on "load more" — frontend only needs messages + if before_sequence is not None: + return SessionDetailResponse( + id=page.session.session_id, + created_at=page.session.started_at.isoformat(), + updated_at=page.session.updated_at.isoformat(), + user_id=page.session.user_id or None, + messages=messages, + active_stream=None, + has_more_messages=page.has_more, + oldest_sequence=page.oldest_sequence, + total_prompt_tokens=0, + total_completion_tokens=0, ) - # Sum token usage from session - total_prompt = sum(u.prompt_tokens for u in session.usage) - total_completion = sum(u.completion_tokens for u in session.usage) + total_prompt = sum(u.prompt_tokens for u in page.session.usage) + total_completion = sum(u.completion_tokens for u in page.session.usage) return SessionDetailResponse( - id=session.session_id, - created_at=session.started_at.isoformat(), - updated_at=session.updated_at.isoformat(), - user_id=session.user_id or None, + id=page.session.session_id, + created_at=page.session.started_at.isoformat(), + updated_at=page.session.updated_at.isoformat(), + user_id=page.session.user_id or None, messages=messages, active_stream=active_stream_info, + has_more_messages=page.has_more, + oldest_sequence=page.oldest_sequence, total_prompt_tokens=total_prompt, total_completion_tokens=total_completion, - metadata=session.metadata, + metadata=page.session.metadata, ) @@ -451,21 +568,27 @@ async def get_session( ) async def get_copilot_usage( user_id: Annotated[str, Security(auth.get_user_id)], -) -> CoPilotUsageStatus: +) -> CoPilotUsagePublic: """Get CoPilot usage status for the authenticated user. - Returns current token usage vs limits for daily and weekly windows. - Global defaults sourced from LaunchDarkly (falling back to config). + Returns the percentage of the daily/weekly allowance used — not the + raw spend or cap — so clients cannot derive per-turn cost or platform + margins. Global defaults sourced from LaunchDarkly (falling back to + config). Includes the user's rate-limit tier. """ - daily_limit, weekly_limit = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + daily_limit, weekly_limit, tier = await get_global_rate_limits( + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) - return await get_usage_status( + status = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, rate_limit_reset_cost=config.rate_limit_reset_cost, + tier=tier, ) + return CoPilotUsagePublic.from_status(status) class RateLimitResetResponse(BaseModel): @@ -474,7 +597,9 @@ class RateLimitResetResponse(BaseModel): success: bool credits_charged: int = Field(description="Credits charged (in cents)") remaining_balance: int = Field(description="Credit balance after charge (in cents)") - usage: CoPilotUsageStatus = Field(description="Updated usage status after reset") + usage: CoPilotUsagePublic = Field( + description="Updated usage status after reset (percentages only)" + ) @router.post( @@ -498,7 +623,7 @@ async def reset_copilot_usage( ) -> RateLimitResetResponse: """Reset the daily CoPilot rate limit by spending credits. - Allows users who have hit their daily token limit to spend credits + Allows users who have hit their daily cost limit to spend credits to reset their daily usage counter and continue working. Returns 400 if the feature is disabled or the user is not over the limit. Returns 402 if the user has insufficient credits. @@ -516,8 +641,10 @@ async def reset_copilot_usage( detail="Rate limit reset is not available (credit system is disabled).", ) - daily_limit, weekly_limit = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + daily_limit, weekly_limit, tier = await get_global_rate_limits( + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) if daily_limit <= 0: @@ -554,8 +681,9 @@ async def reset_copilot_usage( # used for limit checks, not returned to the client.) usage_status = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, + tier=tier, ) if daily_limit > 0 and usage_status.daily.used < daily_limit: raise HTTPException( @@ -589,7 +717,7 @@ async def reset_copilot_usage( # Reset daily usage in Redis. If this fails, refund the credits # so the user is not charged for a service they did not receive. - if not await reset_daily_usage(user_id, daily_token_limit=daily_limit): + if not await reset_daily_usage(user_id, daily_cost_limit=daily_limit): # Compensate: refund the charged credits. refunded = False try: @@ -625,19 +753,20 @@ async def reset_copilot_usage( finally: await release_reset_lock(user_id) - # Return updated usage status. + # Return updated usage status (public schema — percentages only). updated_usage = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, rate_limit_reset_cost=config.rate_limit_reset_cost, + tier=tier, ) return RateLimitResetResponse( success=True, credits_charged=cost, remaining_balance=remaining, - usage=updated_usage, + usage=CoPilotUsagePublic.from_status(updated_usage), ) @@ -688,36 +817,52 @@ async def cancel_session_task( @router.post( "/sessions/{session_id}/stream", + responses={ + 202: { + "model": QueuePendingMessageResponse, + "description": ( + "Session has a turn in flight — message queued into the pending " + "buffer and will be picked up between tool-call rounds by the " + "executor currently processing the turn." + ), + }, + 404: {"description": "Session not found or access denied"}, + 429: {"description": "Cost rate-limit or call-frequency cap exceeded"}, + }, ) async def stream_chat_post( session_id: str, request: StreamChatRequest, user_id: str = Security(auth.get_user_id), ): - """ - Stream chat responses for a session (POST with context support). + """Start a new turn OR queue a follow-up — decided server-side. - Streams the AI/completion responses in real time over Server-Sent Events (SSE), including: - - Text fragments as they are generated - - Tool call UI elements (if invoked) - - Tool execution results + - **Session idle**: starts a turn. Returns an SSE stream (``text/event-stream``) + with Vercel AI SDK chunks (text fragments, tool-call UI, tool results). + The generation runs in a background task that survives client disconnects; + reconnect via ``GET /sessions/{session_id}/stream`` to resume. - The AI generation runs in a background task that continues even if the client disconnects. - All chunks are written to a per-turn Redis stream for reconnection support. If the client - disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume. + - **Session has a turn in flight**: pushes the message into the per-session + pending buffer and returns ``202 application/json`` with + ``QueuePendingMessageResponse``. The executor running the current turn + drains the buffer between tool-call rounds (baseline) or at the start of + the next turn (SDK). Clients should detect the 202 and surface the + message as a queued-chip in the UI. Args: - session_id: The chat session identifier to associate with the streamed messages. - request: Request body containing message, is_user_message, and optional context. + session_id: The chat session identifier. + request: Request body with message, is_user_message, and optional context. user_id: Authenticated user ID. - Returns: - StreamingResponse: SSE-formatted response chunks. - """ import asyncio import time stream_start_time = time.perf_counter() + # Wall-clock arrival time, propagated to the executor so the turn-start + # drain can order pending messages relative to this request (pending + # pushed BEFORE this instant were typed earlier; pending pushed AFTER + # are race-path follow-ups typed while /stream was still processing). + request_arrival_at = time.time() log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id} logger.info( @@ -725,7 +870,28 @@ async def stream_chat_post( f"user={user_id}, message_len={len(request.message)}", extra={"json_fields": log_meta}, ) - await _validate_and_get_session(session_id, user_id) + session = await _validate_and_get_session(session_id, user_id) + builder_permissions = resolve_session_permissions(session) + + # Self-defensive queue-fallback: if a turn is already running, don't race + # it on the cluster lock — drop the message into the pending buffer and + # return 202 so the caller can render a chip. Both UI chips and autopilot + # block follow-ups route through this path; keeping the decision on the + # server means every caller gets uniform behaviour. + if ( + request.is_user_message + and request.message + and await is_turn_in_flight(session_id) + ): + response = await queue_pending_for_http( + session_id=session_id, + user_id=user_id, + message=request.message, + context=request.context, + file_ids=request.file_ids, + ) + return JSONResponse(status_code=202, content=response.model_dump()) + logger.info( f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms", extra={ @@ -736,18 +902,20 @@ async def stream_chat_post( }, ) - # Pre-turn rate limit check (token-based). + # Pre-turn rate limit check (cost-based, microdollars). # check_rate_limit short-circuits internally when both limits are 0. # Global defaults sourced from LaunchDarkly, falling back to config. if user_id: try: - daily_limit, weekly_limit = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + daily_limit, weekly_limit, _ = await get_global_rate_limits( + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) await check_rate_limit( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, ) except RateLimitExceeded as e: raise HTTPException(status_code=429, detail=str(e)) from e @@ -756,87 +924,75 @@ async def stream_chat_post( # Also sanitise file_ids so only validated, workspace-scoped IDs are # forwarded downstream (e.g. to the executor via enqueue_copilot_turn). sanitized_file_ids: list[str] | None = None - if request.file_ids and user_id: - # Filter to valid UUIDs only to prevent DB abuse - valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)] - - if valid_ids: - workspace = await get_or_create_workspace(user_id) - # Batch query instead of N+1 - files = await UserWorkspaceFile.prisma().find_many( - where={ - "id": {"in": valid_ids}, - "workspaceId": workspace.id, - "isDeleted": False, - } - ) - # Only keep IDs that actually exist in the user's workspace - sanitized_file_ids = [wf.id for wf in files] or None - file_lines: list[str] = [ - f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}" - for wf in files - ] - if file_lines: - files_block = ( - "\n\n[Attached files]\n" - + "\n".join(file_lines) - + "\nUse read_workspace_file with the file_id to access file contents." - ) - request.message += files_block + if request.file_ids: + files = await resolve_workspace_files(user_id, request.file_ids) + sanitized_file_ids = [wf.id for wf in files] or None + request.message += build_files_block(files) # Atomically append user message to session BEFORE creating task to avoid # race condition where GET_SESSION sees task as "running" but message isn't - # saved yet. append_and_save_message re-fetches inside a lock to prevent - # message loss from concurrent requests. + # saved yet. append_and_save_message returns None when a duplicate is + # detected — in that case skip enqueue to avoid processing the message twice. + is_duplicate_message = False if request.message: message = ChatMessage( role="user" if request.is_user_message else "assistant", content=request.message, ) - if request.is_user_message: + logger.info(f"[STREAM] Saving user message to session {session_id}") + is_duplicate_message = ( + await append_and_save_message(session_id, message) + ) is None + logger.info(f"[STREAM] User message saved for session {session_id}") + if not is_duplicate_message and request.is_user_message: track_user_message( user_id=user_id, session_id=session_id, message_length=len(request.message), ) - logger.info(f"[STREAM] Saving user message to session {session_id}") - await append_and_save_message(session_id, message) - logger.info(f"[STREAM] User message saved for session {session_id}") - # Create a task in the stream registry for reconnection support - turn_id = str(uuid4()) - log_meta["turn_id"] = turn_id - - session_create_start = time.perf_counter() - await stream_registry.create_session( - session_id=session_id, - user_id=user_id, - tool_call_id="chat_stream", - tool_name="chat", - turn_id=turn_id, - ) - logger.info( - f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms", - extra={ - "json_fields": { - **log_meta, - "duration_ms": (time.perf_counter() - session_create_start) * 1000, - } - }, - ) - - # Per-turn stream is always fresh (unique turn_id), subscribe from beginning - subscribe_from_id = "0-0" - - await enqueue_copilot_turn( - session_id=session_id, - user_id=user_id, - message=request.message, - turn_id=turn_id, - is_user_message=request.is_user_message, - context=request.context, - file_ids=sanitized_file_ids, - ) + # Create a task in the stream registry for reconnection support. + # For duplicate messages, skip create_session entirely so the infra-retry + # client subscribes to the *existing* turn's Redis stream and receives the + # in-progress executor output rather than an empty stream. + turn_id = "" + if not is_duplicate_message: + turn_id = str(uuid4()) + log_meta["turn_id"] = turn_id + session_create_start = time.perf_counter() + await stream_registry.create_session( + session_id=session_id, + user_id=user_id, + tool_call_id="chat_stream", + tool_name="chat", + turn_id=turn_id, + ) + logger.info( + f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms", + extra={ + "json_fields": { + **log_meta, + "duration_ms": (time.perf_counter() - session_create_start) * 1000, + } + }, + ) + await enqueue_copilot_turn( + session_id=session_id, + user_id=user_id, + message=request.message, + turn_id=turn_id, + is_user_message=request.is_user_message, + context=request.context, + file_ids=sanitized_file_ids, + mode=request.mode, + model=request.model, + permissions=builder_permissions, + request_arrival_at=request_arrival_at, + ) + else: + logger.info( + f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue" + ) setup_time = (time.perf_counter() - stream_start_time) * 1000 logger.info( @@ -844,6 +1000,9 @@ async def stream_chat_post( extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}}, ) + # Per-turn stream is always fresh (unique turn_id), subscribe from beginning + subscribe_from_id = "0-0" + # SSE endpoint that subscribes to the task's stream async def event_generator() -> AsyncGenerator[str, None]: import time as time_module @@ -868,7 +1027,6 @@ async def stream_chat_post( if subscriber_queue is None: yield StreamFinish().to_sse() - yield "data: [DONE]\n\n" return # Read from the subscriber queue and yield to SSE @@ -898,7 +1056,6 @@ async def stream_chat_post( yield chunk.to_sse() - # Check for finish signal if isinstance(chunk, StreamFinish): total_time = time_module.perf_counter() - event_gen_start logger.info( @@ -913,6 +1070,7 @@ async def stream_chat_post( }, ) break + except asyncio.TimeoutError: yield StreamHeartbeat().to_sse() @@ -927,7 +1085,6 @@ async def stream_chat_post( } }, ) - pass # Client disconnected - background task continues except Exception as e: elapsed = (time_module.perf_counter() - event_gen_start) * 1000 logger.error( @@ -981,6 +1138,31 @@ async def stream_chat_post( ) +@router.get( + "/sessions/{session_id}/messages/pending", + response_model=PeekPendingMessagesResponse, + responses={ + 404: {"description": "Session not found or access denied"}, + }, +) +async def get_pending_messages( + session_id: str, + user_id: str = Security(auth.get_user_id), +): + """Peek at the pending-message buffer without consuming it. + + Returns the current contents of the session's pending message buffer + so the frontend can restore the queued-message indicator after a page + refresh and clear it correctly once a turn drains the buffer. + """ + await _validate_and_get_session(session_id, user_id) + pending = await peek_pending_messages(session_id) + return PeekPendingMessagesResponse( + messages=[m.content for m in pending], + count=len(pending), + ) + + @router.get( "/sessions/{session_id}/stream", ) @@ -1233,6 +1415,10 @@ ToolResponseUnion = ( | DocPageResponse | MCPToolsDiscoveredResponse | MCPToolOutputResponse + | MemoryStoreResponse + | MemorySearchResponse + | MemoryForgetCandidatesResponse + | MemoryForgetConfirmResponse ) diff --git a/autogpt_platform/backend/backend/api/features/chat/routes_test.py b/autogpt_platform/backend/backend/api/features/chat/routes_test.py index b710bf7c57..11dac08084 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes_test.py @@ -9,10 +9,22 @@ import pytest import pytest_mock from backend.api.features.chat import routes as chat_routes +from backend.api.features.chat.routes import _strip_injected_context +from backend.copilot.rate_limit import SubscriptionTier +from backend.util.exceptions import NotFoundError app = fastapi.FastAPI() app.include_router(chat_routes.router) + +@app.exception_handler(NotFoundError) +async def _not_found_handler( + request: fastapi.Request, exc: NotFoundError +) -> fastapi.responses.JSONResponse: + """Mirror the production NotFoundError → 404 mapping from the REST app.""" + return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)}) + + client = fastapi.testclient.TestClient(app) TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a" @@ -131,16 +143,23 @@ def test_stream_chat_rejects_too_many_file_ids(): assert response.status_code == 422 -def _mock_stream_internals(mocker: pytest_mock.MockFixture): +def _mock_stream_internals(mocker: pytest_mock.MockerFixture): """Mock the async internals of stream_chat_post so tests can exercise - validation and enrichment logic without needing Redis/RabbitMQ.""" + validation and enrichment logic without needing RabbitMQ. + + Returns: + A namespace with ``save`` and ``enqueue`` mock objects so + callers can make additional assertions about side-effects. + """ + import types + mocker.patch( "backend.api.features.chat.routes._validate_and_get_session", return_value=None, ) - mocker.patch( + mock_save = mocker.patch( "backend.api.features.chat.routes.append_and_save_message", - return_value=None, + return_value=MagicMock(), # non-None = message was saved (not a duplicate) ) mock_registry = mocker.MagicMock() mock_registry.create_session = mocker.AsyncMock(return_value=None) @@ -148,7 +167,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture): "backend.api.features.chat.routes.stream_registry", mock_registry, ) - mocker.patch( + mock_enqueue = mocker.patch( "backend.api.features.chat.routes.enqueue_copilot_turn", return_value=None, ) @@ -156,14 +175,17 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture): "backend.api.features.chat.routes.track_user_message", return_value=None, ) + return types.SimpleNamespace( + save=mock_save, enqueue=mock_enqueue, registry=mock_registry + ) -def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture): +def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture): """Exactly 20 file_ids should be accepted (not rejected by validation).""" _mock_stream_internals(mocker) # Patch workspace lookup as imported by the routes module mocker.patch( - "backend.api.features.chat.routes.get_or_create_workspace", + "backend.data.workspace.get_or_create_workspace", return_value=type("W", (), {"id": "ws-1"})(), ) mock_prisma = mocker.MagicMock() @@ -184,15 +206,38 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture): assert response.status_code == 200 +# ─── Duplicate message dedup ────────────────────────────────────────── + + +def test_stream_chat_skips_enqueue_for_duplicate_message( + mocker: pytest_mock.MockerFixture, +): + """When append_and_save_message returns None (duplicate detected), + enqueue_copilot_turn and stream_registry.create_session must NOT be called + to avoid double-processing and to prevent overwriting the active stream's + turn_id in Redis (which would cause reconnecting clients to miss the response).""" + mocks = _mock_stream_internals(mocker) + # Override save to return None — signalling a duplicate + mocks.save.return_value = None + + response = client.post( + "/sessions/sess-1/stream", + json={"message": "hello"}, + ) + assert response.status_code == 200 + mocks.enqueue.assert_not_called() + mocks.registry.create_session.assert_not_called() + + # ─── UUID format filtering ───────────────────────────────────────────── -def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture): +def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockerFixture): """Non-UUID strings in file_ids should be silently filtered out and NOT passed to the database query.""" _mock_stream_internals(mocker) mocker.patch( - "backend.api.features.chat.routes.get_or_create_workspace", + "backend.data.workspace.get_or_create_workspace", return_value=type("W", (), {"id": "ws-1"})(), ) @@ -226,11 +271,11 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture): # ─── Cross-workspace file_ids ───────────────────────────────────────── -def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture): +def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockerFixture): """The batch query should scope to the user's workspace.""" _mock_stream_internals(mocker) mocker.patch( - "backend.api.features.chat.routes.get_or_create_workspace", + "backend.data.workspace.get_or_create_workspace", return_value=type("W", (), {"id": "my-workspace-id"})(), ) @@ -255,14 +300,14 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture): # ─── Rate limit → 429 ───────────────────────────────────────────────── -def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture): +def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerFixture): """When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429.""" from backend.copilot.rate_limit import RateLimitExceeded _mock_stream_internals(mocker) # Ensure the rate-limit branch is entered by setting a non-zero limit. - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)), @@ -276,13 +321,15 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFix assert "daily" in response.json()["detail"].lower() -def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture): +def test_stream_chat_returns_429_on_weekly_rate_limit( + mocker: pytest_mock.MockerFixture, +): """When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429.""" from backend.copilot.rate_limit import RateLimitExceeded _mock_stream_internals(mocker) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) resets_at = datetime.now(UTC) + timedelta(days=3) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", @@ -299,13 +346,13 @@ def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFi assert "resets in" in detail -def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture): +def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture): """The 429 response detail should include the human-readable reset time.""" from backend.copilot.rate_limit import RateLimitExceeded _mock_stream_internals(mocker) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", side_effect=RateLimitExceeded( @@ -331,14 +378,28 @@ def _mock_usage( *, daily_used: int = 500, weekly_used: int = 2000, + daily_limit: int = 10000, + weekly_limit: int = 50000, + tier: "SubscriptionTier" = SubscriptionTier.FREE, ) -> AsyncMock: - """Mock get_usage_status to return a predictable CoPilotUsageStatus.""" + """Mock get_usage_status and get_global_rate_limits for usage endpoint tests. + + Mocks both ``get_global_rate_limits`` (returns the given limits + tier) and + ``get_usage_status`` so that tests exercise the endpoint without hitting + LaunchDarkly or Prisma. + """ from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(daily_limit, weekly_limit, tier), + ) + resets_at = datetime.now(UTC) + timedelta(days=1) status = CoPilotUsageStatus( - daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at), - weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at), + daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at), + weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at), ) return mocker.patch( "backend.api.features.chat.routes.get_usage_status", @@ -351,24 +412,35 @@ def test_usage_returns_daily_and_weekly( mocker: pytest_mock.MockerFixture, test_user_id: str, ) -> None: - """GET /usage returns daily and weekly usage.""" + """GET /usage returns percentages for daily and weekly windows only. + + The raw used/limit microdollar values MUST NOT leak — clients should not + be able to derive per-turn cost or platform margins from the public API. + """ mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) response = client.get("/usage") assert response.status_code == 200 data = response.json() - assert data["daily"]["used"] == 500 - assert data["weekly"]["used"] == 2000 + # 500 / 10000 = 5%, 2000 / 50000 = 4% + assert data["daily"]["percent_used"] == 5.0 + assert data["weekly"]["percent_used"] == 4.0 + # Raw spend/limit must not be exposed. + assert "used" not in data["daily"] + assert "limit" not in data["daily"] + assert "used" not in data["weekly"] + assert "limit" not in data["weekly"] mock_get.assert_called_once_with( user_id=test_user_id, - daily_token_limit=10000, - weekly_token_limit=50000, + daily_cost_limit=10000, + weekly_cost_limit=50000, rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost, + tier=SubscriptionTier.FREE, ) @@ -376,11 +448,9 @@ def test_usage_uses_config_limits( mocker: pytest_mock.MockerFixture, test_user_id: str, ) -> None: - """The endpoint forwards daily_token_limit and weekly_token_limit from config.""" - mock_get = _mock_usage(mocker) + """The endpoint forwards resolved limits from get_global_rate_limits to get_usage_status.""" + mock_get = _mock_usage(mocker, daily_limit=99999, weekly_limit=77777) - mocker.patch.object(chat_routes.config, "daily_token_limit", 99999) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777) mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500) response = client.get("/usage") @@ -388,9 +458,10 @@ def test_usage_uses_config_limits( assert response.status_code == 200 mock_get.assert_called_once_with( user_id=test_user_id, - daily_token_limit=99999, - weekly_token_limit=77777, + daily_cost_limit=99999, + weekly_cost_limit=77777, rate_limit_reset_cost=500, + tier=SubscriptionTier.FREE, ) @@ -526,3 +597,1207 @@ def test_create_session_rejects_nested_metadata( ) assert response.status_code == 422 + + +class TestStreamChatRequestModeValidation: + """Pydantic-level validation of the ``mode`` field on StreamChatRequest.""" + + def test_rejects_invalid_mode_value(self) -> None: + """Any string outside the Literal set must raise ValidationError.""" + from pydantic import ValidationError + + from backend.api.features.chat.routes import StreamChatRequest + + with pytest.raises(ValidationError): + StreamChatRequest(message="hi", mode="turbo") # type: ignore[arg-type] + + def test_accepts_fast_mode(self) -> None: + from backend.api.features.chat.routes import StreamChatRequest + + req = StreamChatRequest(message="hi", mode="fast") + assert req.mode == "fast" + + def test_accepts_extended_thinking_mode(self) -> None: + from backend.api.features.chat.routes import StreamChatRequest + + req = StreamChatRequest(message="hi", mode="extended_thinking") + assert req.mode == "extended_thinking" + + def test_accepts_none_mode(self) -> None: + """``mode=None`` is valid (server decides via feature flags).""" + from backend.api.features.chat.routes import StreamChatRequest + + req = StreamChatRequest(message="hi", mode=None) + assert req.mode is None + + def test_mode_defaults_to_none_when_omitted(self) -> None: + from backend.api.features.chat.routes import StreamChatRequest + + req = StreamChatRequest(message="hi") + assert req.mode is None + + +# ─── POST /stream queue-fallback (when a turn is already in flight) ── + + +def _mock_stream_queue_internals( + mocker: pytest_mock.MockerFixture, + *, + session_exists: bool = True, + turn_in_flight: bool = True, + call_count: int = 1, +): + """Mock dependencies for the POST /stream queue-fallback path. + + When ``turn_in_flight`` is True the handler takes the 202 queue branch. + """ + if session_exists: + mock_session = mocker.MagicMock() + mock_session.id = "sess-1" + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + new_callable=AsyncMock, + return_value=mock_session, + ) + else: + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + side_effect=fastapi.HTTPException( + status_code=404, detail="Session not found." + ), + ) + mocker.patch( + "backend.api.features.chat.routes.is_turn_in_flight", + new_callable=AsyncMock, + return_value=turn_in_flight, + ) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(0, 0, None), + ) + mocker.patch( + "backend.api.features.chat.routes.check_rate_limit", + new_callable=AsyncMock, + return_value=None, + ) + mocker.patch( + "backend.copilot.pending_message_helpers.get_redis_async", + new_callable=AsyncMock, + return_value=mocker.MagicMock(), + ) + mocker.patch( + "backend.copilot.pending_message_helpers.incr_with_ttl", + new_callable=AsyncMock, + return_value=call_count, + ) + mocker.patch( + "backend.copilot.pending_message_helpers.push_pending_message", + new_callable=AsyncMock, + return_value=1, + ) + # queue_user_message re-runs is_turn_in_flight via the helper module — + # stub that path out too so we don't need a fake stream_registry. + mocker.patch( + "backend.copilot.pending_message_helpers.get_active_session_meta", + new_callable=AsyncMock, + return_value=None, + ) + + +def test_stream_queue_returns_202_when_turn_in_flight( + mocker: pytest_mock.MockerFixture, +) -> None: + """Happy path: POST /stream to a session with a live turn → 202 queue.""" + _mock_stream_queue_internals(mocker) + + response = client.post( + "/sessions/sess-1/stream", + json={"message": "follow-up", "is_user_message": True}, + ) + + assert response.status_code == 202 + data = response.json() + assert data["buffer_length"] == 1 + assert "turn_in_flight" in data + + +def test_stream_queue_session_not_found_returns_404( + mocker: pytest_mock.MockerFixture, +) -> None: + """If the session doesn't exist or belong to the user, returns 404.""" + _mock_stream_queue_internals(mocker, session_exists=False) + + response = client.post( + "/sessions/bad-sess/stream", + json={"message": "hi", "is_user_message": True}, + ) + assert response.status_code == 404 + + +def test_stream_queue_call_frequency_limit_returns_429( + mocker: pytest_mock.MockerFixture, +) -> None: + """Per-user call-frequency cap rejects rapid-fire queued pushes.""" + from backend.copilot.pending_message_helpers import PENDING_CALL_LIMIT + + _mock_stream_queue_internals(mocker, call_count=PENDING_CALL_LIMIT + 1) + + response = client.post( + "/sessions/sess-1/stream", + json={"message": "hi", "is_user_message": True}, + ) + assert response.status_code == 429 + assert "Too many queued message requests this minute" in response.json()["detail"] + + +def test_stream_queue_converts_context_dict_to_pending_context( + mocker: pytest_mock.MockerFixture, +) -> None: + """StreamChatRequest.context is a raw dict; must be coerced to the + typed PendingMessageContext before being pushed onto the buffer.""" + _mock_stream_queue_internals(mocker) + queue_spy = mocker.patch( + "backend.copilot.pending_message_helpers.queue_user_message", + new_callable=AsyncMock, + ) + from backend.copilot.pending_message_helpers import QueuePendingMessageResponse + + queue_spy.return_value = QueuePendingMessageResponse( + buffer_length=1, max_buffer_length=10, turn_in_flight=True + ) + + response = client.post( + "/sessions/sess-1/stream", + json={ + "message": "hi", + "is_user_message": True, + "context": {"url": "https://example.test", "content": "body"}, + }, + ) + + assert response.status_code == 202 + queue_spy.assert_awaited_once() + kwargs = queue_spy.await_args.kwargs + from backend.copilot.pending_messages import PendingMessageContext + + assert isinstance(kwargs["context"], PendingMessageContext) + assert kwargs["context"].url == "https://example.test" + assert kwargs["context"].content == "body" + + +def test_stream_queue_passes_none_context_when_omitted( + mocker: pytest_mock.MockerFixture, +) -> None: + """When request.context is omitted, the queue call receives context=None.""" + _mock_stream_queue_internals(mocker) + queue_spy = mocker.patch( + "backend.copilot.pending_message_helpers.queue_user_message", + new_callable=AsyncMock, + ) + from backend.copilot.pending_message_helpers import QueuePendingMessageResponse + + queue_spy.return_value = QueuePendingMessageResponse( + buffer_length=1, max_buffer_length=10, turn_in_flight=True + ) + + response = client.post( + "/sessions/sess-1/stream", + json={"message": "hi", "is_user_message": True}, + ) + + assert response.status_code == 202 + queue_spy.assert_awaited_once() + assert queue_spy.await_args.kwargs["context"] is None + + +# ─── get_pending_messages (GET /sessions/{session_id}/messages/pending) ───── + + +def test_get_pending_messages_returns_200_with_empty_buffer( + mocker: pytest_mock.MockerFixture, +) -> None: + """Happy path: no pending messages returns 200 with empty list.""" + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + new_callable=AsyncMock, + return_value=mocker.MagicMock(), + ) + mocker.patch( + "backend.api.features.chat.routes.peek_pending_messages", + new_callable=AsyncMock, + return_value=[], + ) + + response = client.get("/sessions/sess-1/messages/pending") + + assert response.status_code == 200 + data = response.json() + assert data["messages"] == [] + assert data["count"] == 0 + + +def test_get_pending_messages_returns_queued_messages( + mocker: pytest_mock.MockerFixture, +) -> None: + """Returns pending messages from buffer without consuming them.""" + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + new_callable=AsyncMock, + return_value=mocker.MagicMock(), + ) + mocker.patch( + "backend.api.features.chat.routes.peek_pending_messages", + new_callable=AsyncMock, + return_value=[ + MagicMock(content="first message"), + MagicMock(content="second message"), + ], + ) + + response = client.get("/sessions/sess-1/messages/pending") + + assert response.status_code == 200 + data = response.json() + assert data["count"] == 2 + assert data["messages"] == ["first message", "second message"] + + +def test_get_pending_messages_session_not_found_returns_404( + mocker: pytest_mock.MockerFixture, +) -> None: + """If session does not exist or belongs to another user, returns 404.""" + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + side_effect=fastapi.HTTPException(status_code=404, detail="Session not found."), + ) + + response = client.get("/sessions/bad-sess/messages/pending") + + assert response.status_code == 404 + + +class TestStripInjectedContext: + """Unit tests for `_strip_injected_context` — the GET-side helper that + hides the server-injected `` block from API responses. + + The strip is intentionally exact-match: it only removes the prefix the + inject helper writes (`...\\n\\n` at the very + start of the message). Any drift between writer and reader leaves the raw + block visible in the chat history, which is the failure mode this suite + documents. + """ + + @staticmethod + def _msg(role: str, content): + return {"role": role, "content": content} + + def test_strips_well_formed_prefix(self) -> None: + + original = "\nbiz ctx\n\n\nhello world" + result = _strip_injected_context(self._msg("user", original)) + assert result["content"] == "hello world" + + def test_passes_through_message_without_prefix(self) -> None: + + result = _strip_injected_context(self._msg("user", "just a question")) + assert result["content"] == "just a question" + + def test_only_strips_when_prefix_is_at_start(self) -> None: + """An embedded `` block later in the message must NOT + be stripped — only the leading prefix is server-injected.""" + + content = ( + "I copied this from somewhere: \nfoo\n\n\n" + ) + result = _strip_injected_context(self._msg("user", content)) + assert result["content"] == content + + def test_does_not_strip_with_only_single_newline_separator(self) -> None: + """The strip regex requires `\\n\\n` after the closing tag — a single + newline indicates a different format and must not be touched.""" + + content = "\nfoo\n\nhello" + result = _strip_injected_context(self._msg("user", content)) + assert result["content"] == content + + def test_assistant_messages_pass_through(self) -> None: + + original = "\nfoo\n\n\nhi" + result = _strip_injected_context(self._msg("assistant", original)) + assert result["content"] == original + + def test_non_string_content_passes_through(self) -> None: + """Multimodal / structured content (e.g. list of blocks) is not a + string and must not be touched by the strip helper.""" + + blocks = [{"type": "text", "text": "hello"}] + result = _strip_injected_context(self._msg("user", blocks)) + assert result["content"] is blocks + + def test_strip_with_multiline_understanding(self) -> None: + """The understanding payload spans multiple lines (markdown headings, + bullet points). `re.DOTALL` must allow the regex to span them.""" + + original = ( + "\n" + "# User Business Context\n\n" + "## User\nName: Alice\n\n" + "## Business\nCompany: Acme\n" + "\n\nactual question" + ) + result = _strip_injected_context(self._msg("user", original)) + assert result["content"] == "actual question" + + def test_strip_when_message_is_only_the_prefix(self) -> None: + """An empty user message gets injected with just the prefix; the + strip should yield an empty string.""" + + original = "\nctx\n\n\n" + result = _strip_injected_context(self._msg("user", original)) + assert result["content"] == "" + + def test_does_not_mutate_original_dict(self) -> None: + """The helper must return a copy — the original dict stays intact.""" + original_content = "\nctx\n\n\nhello" + msg = self._msg("user", original_content) + result = _strip_injected_context(msg) + assert result["content"] == "hello" + assert msg["content"] == original_content + assert result is not msg + + def test_no_role_field_does_not_crash(self) -> None: + + msg = {"content": "hello"} + result = _strip_injected_context(msg) + # Without a role, the helper short-circuits without touching content. + assert result["content"] == "hello" + + +# ─── message max_length validation ─────────────────────────────────── + + +def test_stream_chat_rejects_too_long_message(): + """A message exceeding max_length=64_000 must be rejected (422).""" + response = client.post( + "/sessions/sess-1/stream", + json={ + "message": "x" * 64_001, + }, + ) + assert response.status_code == 422 + + +def test_stream_chat_accepts_exactly_max_length_message( + mocker: pytest_mock.MockFixture, +): + """A message exactly at max_length=64_000 must be accepted.""" + _mock_stream_internals(mocker) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(0, 0, SubscriptionTier.FREE), + ) + + response = client.post( + "/sessions/sess-1/stream", + json={ + "message": "x" * 64_000, + }, + ) + assert response.status_code == 200 + + +# ─── list_sessions ──────────────────────────────────────────────────── + + +def _make_session_info(session_id: str = "sess-1", title: str | None = "Test"): + """Build a minimal ChatSessionInfo-like mock.""" + from backend.copilot.model import ChatSessionInfo, ChatSessionMetadata + + return ChatSessionInfo( + session_id=session_id, + user_id=TEST_USER_ID, + title=title, + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + metadata=ChatSessionMetadata(), + ) + + +def test_list_sessions_returns_sessions(mocker: pytest_mock.MockerFixture) -> None: + """GET /sessions returns list of sessions with is_processing=False when Redis OK.""" + session = _make_session_info("sess-abc") + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + return_value=([session], 1), + ) + # Redis pipeline returns "done" (not "running") for this session + mock_redis = MagicMock() + mock_pipe = MagicMock() + mock_pipe.hget = MagicMock(return_value=None) + mock_pipe.execute = AsyncMock(return_value=["done"]) + mock_redis.pipeline = MagicMock(return_value=mock_pipe) + mocker.patch( + "backend.api.features.chat.routes.get_redis_async", + new_callable=AsyncMock, + return_value=mock_redis, + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 1 + assert len(data["sessions"]) == 1 + assert data["sessions"][0]["id"] == "sess-abc" + assert data["sessions"][0]["is_processing"] is False + + +def test_list_sessions_marks_running_as_processing( + mocker: pytest_mock.MockerFixture, +) -> None: + """Sessions with Redis status='running' should have is_processing=True.""" + session = _make_session_info("sess-xyz") + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + return_value=([session], 1), + ) + mock_redis = MagicMock() + mock_pipe = MagicMock() + mock_pipe.hget = MagicMock(return_value=None) + mock_pipe.execute = AsyncMock(return_value=["running"]) + mock_redis.pipeline = MagicMock(return_value=mock_pipe) + mocker.patch( + "backend.api.features.chat.routes.get_redis_async", + new_callable=AsyncMock, + return_value=mock_redis, + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + assert response.json()["sessions"][0]["is_processing"] is True + + +def test_list_sessions_redis_failure_defaults_to_not_processing( + mocker: pytest_mock.MockerFixture, +) -> None: + """Redis failures must be swallowed and sessions default to is_processing=False.""" + session = _make_session_info("sess-fallback") + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + return_value=([session], 1), + ) + mocker.patch( + "backend.api.features.chat.routes.get_redis_async", + side_effect=Exception("Redis down"), + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + assert response.json()["sessions"][0]["is_processing"] is False + + +def test_list_sessions_empty(mocker: pytest_mock.MockerFixture) -> None: + """GET /sessions with no sessions returns empty list without hitting Redis.""" + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + return_value=([], 0), + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 0 + assert data["sessions"] == [] + + +# ─── delete_session ─────────────────────────────────────────────────── + + +def test_delete_session_success(mocker: pytest_mock.MockerFixture) -> None: + """DELETE /sessions/{id} returns 204 when deleted successfully.""" + mocker.patch( + "backend.api.features.chat.routes.delete_chat_session", + new_callable=AsyncMock, + return_value=True, + ) + # Patch use_e2b_sandbox env-var to disable E2B so the route skips sandbox cleanup. + # Patching the Pydantic property directly doesn't work (Pydantic v2 intercepts + # attribute setting on BaseSettings instances and raises AttributeError). + mocker.patch.dict("os.environ", {"USE_E2B_SANDBOX": "false"}) + + response = client.delete("/sessions/sess-1") + + assert response.status_code == 204 + + +def test_delete_session_not_found(mocker: pytest_mock.MockerFixture) -> None: + """DELETE /sessions/{id} returns 404 when session not found or not owned.""" + mocker.patch( + "backend.api.features.chat.routes.delete_chat_session", + new_callable=AsyncMock, + return_value=False, + ) + + response = client.delete("/sessions/sess-missing") + + assert response.status_code == 404 + + +# ─── cancel_session_task ────────────────────────────────────────────── + + +def _mock_validate_session( + mocker: pytest_mock.MockerFixture, *, session_id: str = "sess-1" +): + """Mock _validate_and_get_session to return a dummy session.""" + from backend.copilot.model import ChatSession + + dummy = ChatSession.new(TEST_USER_ID, dry_run=False) + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + new_callable=AsyncMock, + return_value=dummy, + ) + + +def test_cancel_session_no_active_task(mocker: pytest_mock.MockerFixture) -> None: + """Cancel returns cancelled=True with reason when no stream is active.""" + _mock_validate_session(mocker) + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(None, None)) + mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + + response = client.post("/sessions/sess-1/cancel") + + assert response.status_code == 200 + data = response.json() + assert data["cancelled"] is True + assert data["reason"] == "no_active_session" + + +def test_cancel_session_enqueues_cancel_and_confirms( + mocker: pytest_mock.MockerFixture, +) -> None: + """Cancel enqueues cancel task and returns cancelled=True once stream stops.""" + from backend.copilot.stream_registry import ActiveSession + + _mock_validate_session(mocker) + active_session = ActiveSession( + session_id="sess-1", + user_id=TEST_USER_ID, + tool_call_id="chat_stream", + tool_name="chat", + turn_id="turn-1", + status="running", + ) + stopped_session = ActiveSession( + session_id="sess-1", + user_id=TEST_USER_ID, + tool_call_id="chat_stream", + tool_name="chat", + turn_id="turn-1", + status="completed", + ) + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(active_session, "1-0")) + mock_registry.get_session = AsyncMock(return_value=stopped_session) + mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + mock_enqueue = mocker.patch( + "backend.api.features.chat.routes.enqueue_cancel_task", + new_callable=AsyncMock, + ) + + response = client.post("/sessions/sess-1/cancel") + + assert response.status_code == 200 + assert response.json()["cancelled"] is True + mock_enqueue.assert_called_once_with("sess-1") + + +# ─── session_assign_user ────────────────────────────────────────────── + + +def test_session_assign_user(mocker: pytest_mock.MockerFixture) -> None: + """PATCH /sessions/{id}/assign-user calls assign_user_to_session and returns ok.""" + mock_assign = mocker.patch( + "backend.api.features.chat.routes.chat_service.assign_user_to_session", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.patch("/sessions/sess-1/assign-user") + + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + mock_assign.assert_called_once_with("sess-1", TEST_USER_ID) + + +# ─── get_ttl_config ────────────────────────────────────────────────── + + +def test_get_ttl_config(mocker: pytest_mock.MockerFixture) -> None: + """GET /config/ttl returns correct TTL values derived from config.""" + mocker.patch.object(chat_routes.config, "stream_ttl", 300) + + response = client.get("/config/ttl") + + assert response.status_code == 200 + data = response.json() + assert data["stream_ttl_seconds"] == 300 + assert data["stream_ttl_ms"] == 300_000 + + +# ─── reset_copilot_usage ────────────────────────────────────────────── + + +def _mock_reset_internals( + mocker: pytest_mock.MockerFixture, + *, + cost: int = 100, + enable_credit: bool = True, + daily_limit: int = 10_000, + weekly_limit: int = 50_000, + tier: "SubscriptionTier" = SubscriptionTier.FREE, + daily_used: int = 10_001, + weekly_used: int = 1_000, + reset_count: int | None = 0, + acquire_lock: bool = True, + reset_daily: bool = True, + remaining_balance: int = 9_000, +): + """Set up all dependencies for reset_copilot_usage tests.""" + from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow + + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", cost) + mocker.patch.object(chat_routes.config, "max_daily_resets", 3) + mocker.patch.object(chat_routes.settings.config, "enable_credit", enable_credit) + + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(daily_limit, weekly_limit, tier), + ) + resets_at = datetime.now(UTC) + timedelta(hours=1) + status = CoPilotUsageStatus( + daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at), + weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at), + ) + mocker.patch( + "backend.api.features.chat.routes.get_usage_status", + new_callable=AsyncMock, + return_value=status, + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=reset_count, + ) + mocker.patch( + "backend.api.features.chat.routes.acquire_reset_lock", + new_callable=AsyncMock, + return_value=acquire_lock, + ) + mocker.patch( + "backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + mocker.patch( + "backend.api.features.chat.routes.reset_daily_usage", + new_callable=AsyncMock, + return_value=reset_daily, + ) + mocker.patch( + "backend.api.features.chat.routes.increment_daily_reset_count", + new_callable=AsyncMock, + ) + + mock_credit_model = MagicMock() + mock_credit_model.spend_credits = AsyncMock(return_value=remaining_balance) + mock_credit_model.top_up_credits = AsyncMock(return_value=None) + mocker.patch( + "backend.api.features.chat.routes.get_user_credit_model", + new_callable=AsyncMock, + return_value=mock_credit_model, + ) + return mock_credit_model + + +def test_reset_usage_returns_400_when_cost_is_zero( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when rate_limit_reset_cost <= 0.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 0) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "not available" in response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_credits_disabled( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when credit system is disabled.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.settings.config, "enable_credit", False) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "disabled" in response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_no_daily_limit( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when daily_limit is 0.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(0, 50_000, SubscriptionTier.FREE), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=0, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "nothing to reset" in response.json()["detail"].lower() + + +def test_reset_usage_returns_503_when_redis_unavailable( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 503 when Redis is unavailable for reset count.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(10_000, 50_000, SubscriptionTier.FREE), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 503 + + +def test_reset_usage_returns_429_when_max_resets_reached( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 429 when max daily resets exceeded.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.config, "max_daily_resets", 2) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(10_000, 50_000, SubscriptionTier.FREE), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=2, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 429 + assert "resets" in response.json()["detail"].lower() + + +def test_reset_usage_returns_429_when_lock_not_acquired( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 429 when a concurrent reset is in progress.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.config, "max_daily_resets", 3) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(10_000, 50_000, SubscriptionTier.FREE), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=0, + ) + mocker.patch( + "backend.api.features.chat.routes.acquire_reset_lock", + new_callable=AsyncMock, + return_value=False, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 429 + assert "in progress" in response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_limit_not_reached( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when daily limit has not been reached.""" + _mock_reset_internals(mocker, daily_used=500, daily_limit=10_000) + mocker.patch( + "backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "not reached" in response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_weekly_also_exhausted( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when weekly limit is also exhausted.""" + _mock_reset_internals( + mocker, + daily_used=10_001, + daily_limit=10_000, + weekly_used=50_001, + weekly_limit=50_000, + ) + mocker.patch( + "backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "weekly" in response.json()["detail"].lower() + + +def test_reset_usage_returns_402_when_insufficient_credits( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 402 when credits are insufficient.""" + from backend.util.exceptions import InsufficientBalanceError + + mock_credit = _mock_reset_internals(mocker) + mock_credit.spend_credits = AsyncMock( + side_effect=InsufficientBalanceError( + message="Insufficient balance", + user_id=TEST_USER_ID, + balance=0.0, + amount=100.0, + ) + ) + mocker.patch( + "backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 402 + + +def test_reset_usage_success(mocker: pytest_mock.MockerFixture) -> None: + """POST /usage/reset returns 200 with updated usage on success.""" + _mock_reset_internals(mocker, remaining_balance=8_900) + + response = client.post("/usage/reset") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["credits_charged"] == 100 + assert data["remaining_balance"] == 8_900 + assert "daily" in data["usage"] + assert "weekly" in data["usage"] + + +def test_reset_usage_refunds_on_redis_failure( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 503 and refunds credits when Redis reset fails.""" + mock_credit = _mock_reset_internals(mocker, reset_daily=False) + + response = client.post("/usage/reset") + + assert response.status_code == 503 + # Credits should be refunded via top_up_credits + mock_credit.top_up_credits.assert_called_once() + + +# ─── resume_session_stream ─────────────────────────────────────────── + + +def test_resume_session_stream_no_active_session( + mocker: pytest_mock.MockerFixture, +) -> None: + """GET /sessions/{id}/stream returns 204 when no active session.""" + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(None, None)) + mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + + response = client.get("/sessions/sess-1/stream") + + assert response.status_code == 204 + + +def test_resume_session_stream_no_subscriber_queue( + mocker: pytest_mock.MockerFixture, +) -> None: + """GET /sessions/{id}/stream returns 204 when subscribe_to_session returns None.""" + from backend.copilot.stream_registry import ActiveSession + + active_session = ActiveSession( + session_id="sess-1", + user_id=TEST_USER_ID, + tool_call_id="chat_stream", + tool_name="chat", + turn_id="turn-1", + status="running", + ) + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(active_session, "1-0")) + mock_registry.subscribe_to_session = AsyncMock(return_value=None) + mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + + response = client.get("/sessions/sess-1/stream") + + assert response.status_code == 204 + + +# ─── DELETE /sessions/{id}/stream — disconnect listeners ────────────── + + +def test_disconnect_stream_returns_204_and_awaits_registry( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + mock_session = MagicMock() + mocker.patch( + "backend.api.features.chat.routes.get_chat_session", + new_callable=AsyncMock, + return_value=mock_session, + ) + mock_disconnect = mocker.patch( + "backend.api.features.chat.routes.stream_registry.disconnect_all_listeners", + new_callable=AsyncMock, + return_value=2, + ) + + response = client.delete("/sessions/sess-1/stream") + + assert response.status_code == 204 + mock_disconnect.assert_awaited_once_with("sess-1") + + +def test_disconnect_stream_returns_404_when_session_missing( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + mocker.patch( + "backend.api.features.chat.routes.get_chat_session", + new_callable=AsyncMock, + return_value=None, + ) + mock_disconnect = mocker.patch( + "backend.api.features.chat.routes.stream_registry.disconnect_all_listeners", + new_callable=AsyncMock, + ) + + response = client.delete("/sessions/unknown-session/stream") + + assert response.status_code == 404 + mock_disconnect.assert_not_awaited() + + +# ─── GET /sessions/{session_id} — backward pagination ───────────────────────── + + +def _make_paginated_messages( + mocker: pytest_mock.MockerFixture, *, has_more: bool = False +): + """Return a mock PaginatedMessages and configure the DB patch.""" + from datetime import UTC, datetime + + from backend.copilot.db import PaginatedMessages + from backend.copilot.model import ChatMessage, ChatSessionInfo, ChatSessionMetadata + + now = datetime.now(UTC) + session_info = ChatSessionInfo( + session_id="sess-1", + user_id=TEST_USER_ID, + usage=[], + started_at=now, + updated_at=now, + metadata=ChatSessionMetadata(), + ) + page = PaginatedMessages( + messages=[ChatMessage(role="user", content="hello", sequence=0)], + has_more=has_more, + oldest_sequence=0, + session=session_info, + ) + mock_paginate = mocker.patch( + "backend.api.features.chat.routes.get_chat_messages_paginated", + new_callable=AsyncMock, + return_value=page, + ) + return page, mock_paginate + + +def test_get_session_returns_backward_paginated( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """All sessions use backward (newest-first) pagination.""" + _make_paginated_messages(mocker) + mocker.patch( + "backend.api.features.chat.routes.stream_registry.get_active_session", + new_callable=AsyncMock, + return_value=(None, None), + ) + + response = client.get("/sessions/sess-1") + + assert response.status_code == 200 + data = response.json() + assert data["oldest_sequence"] == 0 + assert "forward_paginated" not in data + assert "newest_sequence" not in data + + +# ─── POST /sessions with builder_graph_id (get-or-create) ────────────── + + +def test_create_session_with_builder_graph_id_uses_get_or_create( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """``POST /sessions`` with ``builder_graph_id`` routes through + ``get_or_create_builder_session`` and returns a session bound to the graph.""" + from backend.copilot.model import ChatSession + + async def _fake_get_or_create(user_id: str, graph_id: str) -> ChatSession: + return ChatSession.new( + user_id, + dry_run=False, + builder_graph_id=graph_id, + ) + + mocker.patch( + "backend.api.features.chat.routes.get_or_create_builder_session", + new_callable=AsyncMock, + side_effect=_fake_get_or_create, + ) + + response = client.post("/sessions", json={"builder_graph_id": "graph-1"}) + + assert response.status_code == 200 + body = response.json() + assert body["metadata"]["builder_graph_id"] == "graph-1" + assert body["metadata"]["dry_run"] is False + + +def test_create_session_with_builder_graph_id_returns_404_when_not_owned( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """``get_or_create_builder_session`` raises ``NotFoundError`` when the + user doesn't own the graph; the route must map that to HTTP 404.""" + + async def _fake_get_or_create(user_id: str, graph_id: str): + raise NotFoundError(f"Graph {graph_id} not found") + + mocker.patch( + "backend.api.features.chat.routes.get_or_create_builder_session", + new_callable=AsyncMock, + side_effect=_fake_get_or_create, + ) + + response = client.post("/sessions", json={"builder_graph_id": "graph-unauthorized"}) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + +def test_create_session_without_builder_graph_id_creates_fresh( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """With no ``builder_graph_id`` the endpoint falls through to the + default ``create_chat_session`` path — no get-or-create lookup.""" + from backend.copilot.model import ChatSession + + gorc = mocker.patch( + "backend.api.features.chat.routes.get_or_create_builder_session", + new_callable=AsyncMock, + ) + + async def _fake_create(user_id: str, *, dry_run: bool) -> ChatSession: + return ChatSession.new(user_id, dry_run=dry_run) + + mocker.patch( + "backend.api.features.chat.routes.create_chat_session", + new_callable=AsyncMock, + side_effect=_fake_create, + ) + + response = client.post("/sessions", json={"dry_run": True}) + + assert response.status_code == 200 + assert response.json()["metadata"]["dry_run"] is True + gorc.assert_not_called() + + +def test_create_session_rejects_unknown_fields( + test_user_id: str, +) -> None: + """Extra request fields are rejected (422) to prevent silent mis-use.""" + response = client.post("/sessions", json={"unexpected": "x"}) + assert response.status_code == 422 + + +def test_resolve_session_permissions_blocks_out_of_scope_tools() -> None: + """Builder-bound sessions return a blacklist of the three tools that + conflict with the panel's graph-bound scope. Regular sessions return + ``None`` so default (unrestricted) behaviour is preserved.""" + from backend.copilot.builder_context import BUILDER_BLOCKED_TOOLS + from backend.copilot.model import ChatSession + + unbound = ChatSession.new("u1", dry_run=False) + assert chat_routes.resolve_session_permissions(unbound) is None + + bound = ChatSession.new("u1", dry_run=False, builder_graph_id="g1") + perms = chat_routes.resolve_session_permissions(bound) + assert perms is not None + assert perms.tools_exclude is True # blacklist, not whitelist + assert sorted(perms.tools) == sorted(BUILDER_BLOCKED_TOOLS) + # Read-side lookups stay available — only write-scope / guide-dup are blocked. + assert "find_block" not in perms.tools + assert "find_agent" not in perms.tools + assert "search_docs" not in perms.tools + # The write tools (edit_agent / run_agent) are NOT blacklisted — they + # enforce scope per-tool via the builder_graph_id guard. + assert "edit_agent" not in perms.tools + assert "run_agent" not in perms.tools diff --git a/autogpt_platform/backend/backend/api/features/library/_add_to_library.py b/autogpt_platform/backend/backend/api/features/library/_add_to_library.py index 243ec1c0d8..e77e22c7f5 100644 --- a/autogpt_platform/backend/backend/api/features/library/_add_to_library.py +++ b/autogpt_platform/backend/backend/api/features/library/_add_to_library.py @@ -12,6 +12,7 @@ import prisma.models import backend.api.features.library.model as library_model import backend.data.graph as graph_db +from backend.api.features.library.db import _fetch_schedule_info from backend.data.graph import GraphModel, GraphSettings from backend.data.includes import library_agent_include from backend.util.exceptions import NotFoundError @@ -117,4 +118,5 @@ async def add_graph_to_library( f"for store listing version #{store_listing_version_id} " f"to library for user #{user_id}" ) - return library_model.LibraryAgent.from_db(added_agent) + schedule_info = await _fetch_schedule_info(user_id, graph_id=graph_model.id) + return library_model.LibraryAgent.from_db(added_agent, schedule_info=schedule_info) diff --git a/autogpt_platform/backend/backend/api/features/library/_add_to_library_test.py b/autogpt_platform/backend/backend/api/features/library/_add_to_library_test.py index 4d4ae9bdcd..dbb8a17626 100644 --- a/autogpt_platform/backend/backend/api/features/library/_add_to_library_test.py +++ b/autogpt_platform/backend/backend/api/features/library/_add_to_library_test.py @@ -21,13 +21,17 @@ async def test_add_graph_to_library_create_new_agent() -> None: "backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db", return_value=converted_agent, ) as mock_from_db, + patch( + "backend.api.features.library._add_to_library._fetch_schedule_info", + new=AsyncMock(return_value={}), + ), ): mock_prisma.return_value.create = AsyncMock(return_value=created_agent) result = await add_graph_to_library("slv-id", graph_model, "user-id") assert result is converted_agent - mock_from_db.assert_called_once_with(created_agent) + mock_from_db.assert_called_once_with(created_agent, schedule_info={}) # Verify create was called with correct data create_call = mock_prisma.return_value.create.call_args create_data = create_call.kwargs["data"] @@ -54,6 +58,10 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None: "backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db", return_value=converted_agent, ) as mock_from_db, + patch( + "backend.api.features.library._add_to_library._fetch_schedule_info", + new=AsyncMock(return_value={}), + ), ): mock_prisma.return_value.create = AsyncMock( side_effect=prisma.errors.UniqueViolationError( @@ -65,7 +73,7 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None: result = await add_graph_to_library("slv-id", graph_model, "user-id") assert result is converted_agent - mock_from_db.assert_called_once_with(updated_agent) + mock_from_db.assert_called_once_with(updated_agent, schedule_info={}) # Verify update was called with correct where and data update_call = mock_prisma.return_value.update.call_args assert update_call.kwargs["where"] == { diff --git a/autogpt_platform/backend/backend/api/features/library/db.py b/autogpt_platform/backend/backend/api/features/library/db.py index fcfc896ea2..0743b461c6 100644 --- a/autogpt_platform/backend/backend/api/features/library/db.py +++ b/autogpt_platform/backend/backend/api/features/library/db.py @@ -1,6 +1,7 @@ import asyncio import itertools import logging +from datetime import datetime, timezone from typing import Literal, Optional import fastapi @@ -43,6 +44,65 @@ config = Config() integration_creds_manager = IntegrationCredentialsManager() +async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]: + """Fetch execution counts per graph in a single batched query.""" + if not graph_ids: + return {} + rows = await prisma.models.AgentGraphExecution.prisma().group_by( + by=["agentGraphId"], + where={ + "userId": user_id, + "agentGraphId": {"in": graph_ids}, + "isDeleted": False, + }, + count=True, + ) + return { + row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0) + for row in rows + } + + +async def _fetch_schedule_info( + user_id: str, graph_id: Optional[str] = None +) -> dict[str, str]: + """Fetch a map of graph_id → earliest next_run_time ISO string. + + When `graph_id` is provided, the scheduler query is narrowed to that graph, + which is cheaper for single-agent lookups (detail page, post-update, etc.). + """ + try: + scheduler_client = get_scheduler_client() + schedules = await scheduler_client.get_execution_schedules( + graph_id=graph_id, + user_id=user_id, + ) + earliest: dict[str, tuple[datetime, str]] = {} + for s in schedules: + parsed = _parse_iso_datetime(s.next_run_time) + if parsed is None: + continue + current = earliest.get(s.graph_id) + if current is None or parsed < current[0]: + earliest[s.graph_id] = (parsed, s.next_run_time) + return {graph_id: iso for graph_id, (_, iso) in earliest.items()} + except Exception: + logger.warning("Failed to fetch schedules for library agents", exc_info=True) + return {} + + +def _parse_iso_datetime(value: str) -> Optional[datetime]: + """Parse an ISO 8601 datetime, tolerating `Z` and naive forms (assumed UTC).""" + try: + parsed = datetime.fromisoformat(value.replace("Z", "+00:00")) + except ValueError: + logger.warning("Failed to parse schedule next_run_time: %s", value) + return None + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed + + async def list_library_agents( user_id: str, search_term: Optional[str] = None, @@ -137,12 +197,22 @@ async def list_library_agents( logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}") + graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId] + execution_counts, schedule_info = await asyncio.gather( + _fetch_execution_counts(user_id, graph_ids), + _fetch_schedule_info(user_id), + ) + # Only pass valid agents to the response valid_library_agents: list[library_model.LibraryAgent] = [] for agent in library_agents: try: - library_agent = library_model.LibraryAgent.from_db(agent) + library_agent = library_model.LibraryAgent.from_db( + agent, + execution_count_override=execution_counts.get(agent.agentGraphId), + schedule_info=schedule_info, + ) valid_library_agents.append(library_agent) except Exception as e: # Skip this agent if there was an error @@ -214,12 +284,22 @@ async def list_favorite_library_agents( f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}" ) + graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId] + execution_counts, schedule_info = await asyncio.gather( + _fetch_execution_counts(user_id, graph_ids), + _fetch_schedule_info(user_id), + ) + # Only pass valid agents to the response valid_library_agents: list[library_model.LibraryAgent] = [] for agent in library_agents: try: - library_agent = library_model.LibraryAgent.from_db(agent) + library_agent = library_model.LibraryAgent.from_db( + agent, + execution_count_override=execution_counts.get(agent.agentGraphId), + schedule_info=schedule_info, + ) valid_library_agents.append(library_agent) except Exception as e: # Skip this agent if there was an error @@ -285,6 +365,12 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent where={"userId": store_listing.owningUserId} ) + schedule_info = ( + await _fetch_schedule_info(user_id, graph_id=library_agent.AgentGraph.id) + if library_agent.AgentGraph + else {} + ) + return library_model.LibraryAgent.from_db( library_agent, sub_graphs=( @@ -294,6 +380,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent ), store_listing=store_listing, profile=profile, + schedule_info=schedule_info, ) @@ -329,7 +416,10 @@ async def get_library_agent_by_store_version_id( }, include=library_agent_include(user_id), ) - return library_model.LibraryAgent.from_db(agent) if agent else None + if not agent: + return None + schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId) + return library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info) async def get_library_agent_by_graph_id( @@ -358,7 +448,10 @@ async def get_library_agent_by_graph_id( assert agent.AgentGraph # make type checker happy # Include sub-graphs so we can make a full credentials input schema sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph) - return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs) + schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId) + return library_model.LibraryAgent.from_db( + agent, sub_graphs=sub_graphs, schedule_info=schedule_info + ) async def add_generated_agent_image( @@ -500,7 +593,11 @@ async def create_library_agent( for agent, graph in zip(library_agents, graph_entries): asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id)) - return [library_model.LibraryAgent.from_db(agent) for agent in library_agents] + schedule_info = await _fetch_schedule_info(user_id) + return [ + library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info) + for agent in library_agents + ] async def update_agent_version_in_library( @@ -562,7 +659,8 @@ async def update_agent_version_in_library( f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}" ) - return library_model.LibraryAgent.from_db(lib) + schedule_info = await _fetch_schedule_info(user_id, graph_id=agent_graph_id) + return library_model.LibraryAgent.from_db(lib, schedule_info=schedule_info) async def create_graph_in_library( @@ -645,6 +743,7 @@ async def update_library_agent_version_and_settings( graph=agent_graph, hitl_safe_mode=library.settings.human_in_the_loop_safe_mode, sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode, + builder_chat_session_id=library.settings.builder_chat_session_id, ) if updated_settings != library.settings: library = await update_library_agent( @@ -1467,7 +1566,11 @@ async def bulk_move_agents_to_folder( ), ) - return [library_model.LibraryAgent.from_db(agent) for agent in agents] + schedule_info = await _fetch_schedule_info(user_id) + return [ + library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info) + for agent in agents + ] def collect_tree_ids( diff --git a/autogpt_platform/backend/backend/api/features/library/db_test.py b/autogpt_platform/backend/backend/api/features/library/db_test.py index 5e3e36ac63..562a0bfdfd 100644 --- a/autogpt_platform/backend/backend/api/features/library/db_test.py +++ b/autogpt_platform/backend/backend/api/features/library/db_test.py @@ -65,6 +65,11 @@ async def test_get_library_agents(mocker): ) mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1) + mocker.patch( + "backend.api.features.library.db._fetch_execution_counts", + new=mocker.AsyncMock(return_value={}), + ) + # Call function result = await db.list_library_agents("test-user") @@ -353,3 +358,136 @@ async def test_create_library_agent_uses_upsert(): # Verify update branch restores soft-deleted/archived agents assert data["update"]["isDeleted"] is False assert data["update"]["isArchived"] is False + + +@pytest.mark.asyncio +async def test_list_favorite_library_agents(mocker): + mock_library_agents = [ + prisma.models.LibraryAgent( + id="fav1", + userId="test-user", + agentGraphId="agent-fav", + settings="{}", # type: ignore + agentGraphVersion=1, + isCreatedByUser=False, + isDeleted=False, + isArchived=False, + createdAt=datetime.now(), + updatedAt=datetime.now(), + isFavorite=True, + useGraphIsActiveVersion=True, + AgentGraph=prisma.models.AgentGraph( + id="agent-fav", + version=1, + name="Favorite Agent", + description="My Favorite", + userId="other-user", + isActive=True, + createdAt=datetime.now(), + ), + ) + ] + + mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma") + mock_library_agent.return_value.find_many = mocker.AsyncMock( + return_value=mock_library_agents + ) + mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1) + + mocker.patch( + "backend.api.features.library.db._fetch_execution_counts", + new=mocker.AsyncMock(return_value={"agent-fav": 7}), + ) + + result = await db.list_favorite_library_agents("test-user") + + assert len(result.agents) == 1 + assert result.agents[0].id == "fav1" + assert result.agents[0].name == "Favorite Agent" + assert result.agents[0].graph_id == "agent-fav" + assert result.pagination.total_items == 1 + assert result.pagination.total_pages == 1 + assert result.pagination.current_page == 1 + assert result.pagination.page_size == 50 + + +@pytest.mark.asyncio +async def test_list_library_agents_skips_failed_agent(mocker): + """Agents that fail parsing should be skipped — covers the except branch.""" + mock_library_agents = [ + prisma.models.LibraryAgent( + id="ua-bad", + userId="test-user", + agentGraphId="agent-bad", + settings="{}", # type: ignore + agentGraphVersion=1, + isCreatedByUser=False, + isDeleted=False, + isArchived=False, + createdAt=datetime.now(), + updatedAt=datetime.now(), + isFavorite=False, + useGraphIsActiveVersion=True, + AgentGraph=prisma.models.AgentGraph( + id="agent-bad", + version=1, + name="Bad Agent", + description="", + userId="other-user", + isActive=True, + createdAt=datetime.now(), + ), + ) + ] + + mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma") + mock_library_agent.return_value.find_many = mocker.AsyncMock( + return_value=mock_library_agents + ) + mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1) + + mocker.patch( + "backend.api.features.library.db._fetch_execution_counts", + new=mocker.AsyncMock(return_value={}), + ) + mocker.patch( + "backend.api.features.library.model.LibraryAgent.from_db", + side_effect=Exception("parse error"), + ) + + result = await db.list_library_agents("test-user") + + assert len(result.agents) == 0 + assert result.pagination.total_items == 1 + + +@pytest.mark.asyncio +async def test_fetch_execution_counts_empty_graph_ids(): + result = await db._fetch_execution_counts("user-1", []) + assert result == {} + + +@pytest.mark.asyncio +async def test_fetch_execution_counts_uses_group_by(mocker): + mock_prisma = mocker.patch("prisma.models.AgentGraphExecution.prisma") + mock_prisma.return_value.group_by = mocker.AsyncMock( + return_value=[ + {"agentGraphId": "graph-1", "_count": {"_all": 5}}, + {"agentGraphId": "graph-2", "_count": {"_all": 2}}, + ] + ) + + result = await db._fetch_execution_counts( + "user-1", ["graph-1", "graph-2", "graph-3"] + ) + + assert result == {"graph-1": 5, "graph-2": 2} + mock_prisma.return_value.group_by.assert_called_once_with( + by=["agentGraphId"], + where={ + "userId": "user-1", + "agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]}, + "isDeleted": False, + }, + count=True, + ) diff --git a/autogpt_platform/backend/backend/api/features/library/model.py b/autogpt_platform/backend/backend/api/features/library/model.py index 7211a7ebfe..8bd4a9edab 100644 --- a/autogpt_platform/backend/backend/api/features/library/model.py +++ b/autogpt_platform/backend/backend/api/features/library/model.py @@ -214,6 +214,14 @@ class LibraryAgent(pydantic.BaseModel): folder_name: str | None = None # Denormalized for display recommended_schedule_cron: str | None = None + is_scheduled: bool = pydantic.Field( + default=False, + description="Whether this agent has active execution schedules", + ) + next_scheduled_run: str | None = pydantic.Field( + default=None, + description="ISO 8601 timestamp of the next scheduled run, if any", + ) settings: GraphSettings = pydantic.Field(default_factory=GraphSettings) marketplace_listing: Optional["MarketplaceListing"] = None @@ -223,6 +231,8 @@ class LibraryAgent(pydantic.BaseModel): sub_graphs: Optional[list[prisma.models.AgentGraph]] = None, store_listing: Optional[prisma.models.StoreListing] = None, profile: Optional[prisma.models.Profile] = None, + execution_count_override: Optional[int] = None, + schedule_info: Optional[dict[str, str]] = None, ) -> "LibraryAgent": """ Factory method that constructs a LibraryAgent from a Prisma LibraryAgent @@ -258,10 +268,14 @@ class LibraryAgent(pydantic.BaseModel): status = status_result.status new_output = status_result.new_output - execution_count = len(executions) + execution_count = ( + execution_count_override + if execution_count_override is not None + else len(executions) + ) success_rate: float | None = None avg_correctness_score: float | None = None - if execution_count > 0: + if executions and execution_count > 0: success_count = sum( 1 for e in executions @@ -354,6 +368,10 @@ class LibraryAgent(pydantic.BaseModel): folder_id=agent.folderId, folder_name=agent.Folder.name if agent.Folder else None, recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron, + is_scheduled=bool(schedule_info and agent.agentGraphId in schedule_info), + next_scheduled_run=( + schedule_info.get(agent.agentGraphId) if schedule_info else None + ), settings=_parse_settings(agent.settings), marketplace_listing=marketplace_listing_data, ) diff --git a/autogpt_platform/backend/backend/api/features/library/model_test.py b/autogpt_platform/backend/backend/api/features/library/model_test.py index a32b19322d..31924a1793 100644 --- a/autogpt_platform/backend/backend/api/features/library/model_test.py +++ b/autogpt_platform/backend/backend/api/features/library/model_test.py @@ -1,11 +1,66 @@ import datetime +import prisma.enums import prisma.models import pytest from . import model as library_model +def _make_library_agent( + *, + graph_id: str = "g1", + executions: list | None = None, +) -> prisma.models.LibraryAgent: + return prisma.models.LibraryAgent( + id="la1", + userId="u1", + agentGraphId=graph_id, + settings="{}", # type: ignore + agentGraphVersion=1, + isCreatedByUser=True, + isDeleted=False, + isArchived=False, + createdAt=datetime.datetime.now(), + updatedAt=datetime.datetime.now(), + isFavorite=False, + useGraphIsActiveVersion=True, + AgentGraph=prisma.models.AgentGraph( + id=graph_id, + version=1, + name="Agent", + description="Desc", + userId="u1", + isActive=True, + createdAt=datetime.datetime.now(), + Executions=executions, + ), + ) + + +def test_from_db_execution_count_override_covers_success_rate(): + """Covers execution_count_override is not None branch and executions/count > 0 block.""" + now = datetime.datetime.now(datetime.timezone.utc) + exec1 = prisma.models.AgentGraphExecution( + id="exec-1", + agentGraphId="g1", + agentGraphVersion=1, + userId="u1", + executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED, + createdAt=now, + updatedAt=now, + isDeleted=False, + isShared=False, + ) + agent = _make_library_agent(executions=[exec1]) + + result = library_model.LibraryAgent.from_db(agent, execution_count_override=1) + + assert result.execution_count == 1 + assert result.success_rate is not None + assert result.success_rate == 100.0 + + @pytest.mark.asyncio async def test_agent_preset_from_db(test_user_id: str): # Create mock DB agent diff --git a/autogpt_platform/backend/backend/api/features/platform_linking/__init__.py b/autogpt_platform/backend/backend/api/features/platform_linking/__init__.py new file mode 100644 index 0000000000..7764686098 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/platform_linking/__init__.py @@ -0,0 +1 @@ +"""Platform bot linking — user-facing REST routes.""" diff --git a/autogpt_platform/backend/backend/api/features/platform_linking/routes.py b/autogpt_platform/backend/backend/api/features/platform_linking/routes.py new file mode 100644 index 0000000000..7b0f845c01 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/platform_linking/routes.py @@ -0,0 +1,158 @@ +"""User-facing platform_linking REST routes (JWT auth).""" + +import logging +from typing import Annotated + +from autogpt_libs import auth +from fastapi import APIRouter, HTTPException, Path, Security + +from backend.data.db_accessors import platform_linking_db +from backend.platform_linking.models import ( + ConfirmLinkResponse, + ConfirmUserLinkResponse, + DeleteLinkResponse, + LinkTokenInfoResponse, + PlatformLinkInfo, + PlatformUserLinkInfo, +) +from backend.util.exceptions import ( + LinkAlreadyExistsError, + LinkFlowMismatchError, + LinkTokenExpiredError, + NotAuthorizedError, + NotFoundError, +) + +logger = logging.getLogger(__name__) + +router = APIRouter() + +TokenPath = Annotated[ + str, + Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"), +] + + +def _translate(exc: Exception) -> HTTPException: + if isinstance(exc, NotFoundError): + return HTTPException(status_code=404, detail=str(exc)) + if isinstance(exc, NotAuthorizedError): + return HTTPException(status_code=403, detail=str(exc)) + if isinstance(exc, LinkAlreadyExistsError): + return HTTPException(status_code=409, detail=str(exc)) + if isinstance(exc, LinkTokenExpiredError): + return HTTPException(status_code=410, detail=str(exc)) + if isinstance(exc, LinkFlowMismatchError): + return HTTPException(status_code=400, detail=str(exc)) + return HTTPException(status_code=500, detail="Internal error.") + + +@router.get( + "/tokens/{token}/info", + response_model=LinkTokenInfoResponse, + dependencies=[Security(auth.requires_user)], + summary="Get display info for a link token", +) +async def get_link_token_info_route(token: TokenPath) -> LinkTokenInfoResponse: + try: + return await platform_linking_db().get_link_token_info(token) + except (NotFoundError, LinkTokenExpiredError) as exc: + raise _translate(exc) from exc + + +@router.post( + "/tokens/{token}/confirm", + response_model=ConfirmLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Confirm a SERVER link token (user must be authenticated)", +) +async def confirm_link_token( + token: TokenPath, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> ConfirmLinkResponse: + try: + return await platform_linking_db().confirm_server_link(token, user_id) + except ( + NotFoundError, + LinkFlowMismatchError, + LinkTokenExpiredError, + LinkAlreadyExistsError, + ) as exc: + raise _translate(exc) from exc + + +@router.post( + "/user-tokens/{token}/confirm", + response_model=ConfirmUserLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Confirm a USER link token (user must be authenticated)", +) +async def confirm_user_link_token( + token: TokenPath, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> ConfirmUserLinkResponse: + try: + return await platform_linking_db().confirm_user_link(token, user_id) + except ( + NotFoundError, + LinkFlowMismatchError, + LinkTokenExpiredError, + LinkAlreadyExistsError, + ) as exc: + raise _translate(exc) from exc + + +@router.get( + "/links", + response_model=list[PlatformLinkInfo], + dependencies=[Security(auth.requires_user)], + summary="List all platform servers linked to the authenticated user", +) +async def list_my_links( + user_id: Annotated[str, Security(auth.get_user_id)], +) -> list[PlatformLinkInfo]: + return await platform_linking_db().list_server_links(user_id) + + +@router.get( + "/user-links", + response_model=list[PlatformUserLinkInfo], + dependencies=[Security(auth.requires_user)], + summary="List all DM links for the authenticated user", +) +async def list_my_user_links( + user_id: Annotated[str, Security(auth.get_user_id)], +) -> list[PlatformUserLinkInfo]: + return await platform_linking_db().list_user_links(user_id) + + +@router.delete( + "/links/{link_id}", + response_model=DeleteLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Unlink a platform server", +) +async def delete_link( + link_id: str, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> DeleteLinkResponse: + try: + return await platform_linking_db().delete_server_link(link_id, user_id) + except (NotFoundError, NotAuthorizedError) as exc: + raise _translate(exc) from exc + + +@router.delete( + "/user-links/{link_id}", + response_model=DeleteLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Unlink a DM / user link", +) +async def delete_user_link_route( + link_id: str, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> DeleteLinkResponse: + try: + return await platform_linking_db().delete_user_link(link_id, user_id) + except (NotFoundError, NotAuthorizedError) as exc: + raise _translate(exc) from exc diff --git a/autogpt_platform/backend/backend/api/features/platform_linking/routes_test.py b/autogpt_platform/backend/backend/api/features/platform_linking/routes_test.py new file mode 100644 index 0000000000..944ef8eb6a --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/platform_linking/routes_test.py @@ -0,0 +1,264 @@ +"""Route tests: domain exceptions → HTTPException status codes.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +from backend.util.exceptions import ( + LinkAlreadyExistsError, + LinkFlowMismatchError, + LinkTokenExpiredError, + NotAuthorizedError, + NotFoundError, +) + + +def _db_mock(**method_configs): + """Return a mock of the accessor's return value with the given AsyncMocks.""" + db = MagicMock() + for name, mock in method_configs.items(): + setattr(db, name, mock) + return db + + +class TestTokenInfoRouteTranslation: + @pytest.mark.asyncio + async def test_not_found_maps_to_404(self): + from backend.api.features.platform_linking.routes import ( + get_link_token_info_route, + ) + + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await get_link_token_info_route(token="abc") + assert exc.value.status_code == 404 + + @pytest.mark.asyncio + async def test_expired_maps_to_410(self): + from backend.api.features.platform_linking.routes import ( + get_link_token_info_route, + ) + + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=LinkTokenExpiredError("expired")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await get_link_token_info_route(token="abc") + assert exc.value.status_code == 410 + + +class TestConfirmLinkRouteTranslation: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "exc,expected_status", + [ + (NotFoundError("missing"), 404), + (LinkFlowMismatchError("wrong flow"), 400), + (LinkTokenExpiredError("expired"), 410), + (LinkAlreadyExistsError("already"), 409), + ], + ) + async def test_translation(self, exc: Exception, expected_status: int): + from backend.api.features.platform_linking.routes import confirm_link_token + + db = _db_mock(confirm_server_link=AsyncMock(side_effect=exc)) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as ctx: + await confirm_link_token(token="abc", user_id="u1") + assert ctx.value.status_code == expected_status + + +class TestConfirmUserLinkRouteTranslation: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "exc,expected_status", + [ + (NotFoundError("missing"), 404), + (LinkFlowMismatchError("wrong flow"), 400), + (LinkTokenExpiredError("expired"), 410), + (LinkAlreadyExistsError("already"), 409), + ], + ) + async def test_translation(self, exc: Exception, expected_status: int): + from backend.api.features.platform_linking.routes import confirm_user_link_token + + db = _db_mock(confirm_user_link=AsyncMock(side_effect=exc)) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as ctx: + await confirm_user_link_token(token="abc", user_id="u1") + assert ctx.value.status_code == expected_status + + +class TestDeleteLinkRouteTranslation: + @pytest.mark.asyncio + async def test_not_found_maps_to_404(self): + from backend.api.features.platform_linking.routes import delete_link + + db = _db_mock( + delete_server_link=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_link(link_id="x", user_id="u1") + assert exc.value.status_code == 404 + + @pytest.mark.asyncio + async def test_not_owned_maps_to_403(self): + from backend.api.features.platform_linking.routes import delete_link + + db = _db_mock( + delete_server_link=AsyncMock(side_effect=NotAuthorizedError("nope")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_link(link_id="x", user_id="u1") + assert exc.value.status_code == 403 + + +class TestDeleteUserLinkRouteTranslation: + @pytest.mark.asyncio + async def test_not_found_maps_to_404(self): + from backend.api.features.platform_linking.routes import delete_user_link_route + + db = _db_mock(delete_user_link=AsyncMock(side_effect=NotFoundError("missing"))) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_user_link_route(link_id="x", user_id="u1") + assert exc.value.status_code == 404 + + @pytest.mark.asyncio + async def test_not_owned_maps_to_403(self): + from backend.api.features.platform_linking.routes import delete_user_link_route + + db = _db_mock( + delete_user_link=AsyncMock(side_effect=NotAuthorizedError("nope")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_user_link_route(link_id="x", user_id="u1") + assert exc.value.status_code == 403 + + +# ── Adversarial: malformed token path params ────────────────────────── + + +class TestAdversarialTokenPath: + # TokenPath enforces `^[A-Za-z0-9_-]+$` + max_length=64. + + @pytest.fixture + def client(self): + import fastapi + from autogpt_libs.auth import get_user_id, requires_user + from fastapi.testclient import TestClient + + import backend.api.features.platform_linking.routes as routes_mod + + app = fastapi.FastAPI() + app.dependency_overrides[requires_user] = lambda: None + app.dependency_overrides[get_user_id] = lambda: "caller-user" + app.include_router(routes_mod.router, prefix="/api/platform-linking") + return TestClient(app) + + def test_rejects_token_with_special_chars(self, client): + response = client.get("/api/platform-linking/tokens/bad%24token/info") + assert response.status_code == 422 + + def test_rejects_token_with_path_traversal(self, client): + for probe in ("..%2F..", "foo..bar", "foo%2Fbar"): + response = client.get(f"/api/platform-linking/tokens/{probe}/info") + assert response.status_code in ( + 404, + 422, + ), f"path-traversal probe {probe!r} returned {response.status_code}" + + def test_rejects_token_too_long(self, client): + long_token = "a" * 65 + response = client.get(f"/api/platform-linking/tokens/{long_token}/info") + assert response.status_code == 422 + + def test_accepts_token_at_max_length(self, client): + token = "a" * 64 + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + response = client.get(f"/api/platform-linking/tokens/{token}/info") + assert response.status_code == 404 + + def test_accepts_urlsafe_b64_token_shape(self, client): + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + response = client.get("/api/platform-linking/tokens/abc-_XYZ123-_abc/info") + assert response.status_code == 404 + + def test_confirm_rejects_malformed_token(self, client): + response = client.post("/api/platform-linking/tokens/bad%24token/confirm") + assert response.status_code == 422 + + +class TestAdversarialDeleteLinkId: + """DELETE link_id has no regex — ensure weird values are handled via + NotFoundError (no crash, no cross-user leak).""" + + @pytest.fixture + def client(self): + import fastapi + from autogpt_libs.auth import get_user_id, requires_user + from fastapi.testclient import TestClient + + import backend.api.features.platform_linking.routes as routes_mod + + app = fastapi.FastAPI() + app.dependency_overrides[requires_user] = lambda: None + app.dependency_overrides[get_user_id] = lambda: "caller-user" + app.include_router(routes_mod.router, prefix="/api/platform-linking") + return TestClient(app) + + def test_weird_link_id_returns_404(self, client): + db = _db_mock( + delete_server_link=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + for link_id in ("'; DROP TABLE links;--", "../../etc/passwd", ""): + response = client.delete(f"/api/platform-linking/links/{link_id}") + assert response.status_code in (404, 405) diff --git a/autogpt_platform/backend/backend/api/features/store/db_test.py b/autogpt_platform/backend/backend/api/features/store/db_test.py index 35946b8980..f3acd867d3 100644 --- a/autogpt_platform/backend/backend/api/features/store/db_test.py +++ b/autogpt_platform/backend/backend/api/features/store/db_test.py @@ -189,6 +189,7 @@ async def test_create_store_submission(mocker): notifyOnAgentApproved=True, notifyOnAgentRejected=True, timezone="Europe/Delft", + subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue] ) mock_agent = prisma.models.AgentGraph( id="agent-id", diff --git a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py new file mode 100644 index 0000000000..96fd8763eb --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py @@ -0,0 +1,1120 @@ +"""Tests for subscription tier API endpoints.""" + +from unittest.mock import AsyncMock, Mock + +import fastapi +import fastapi.testclient +import pytest +import pytest_mock +import stripe +from autogpt_libs.auth.jwt_utils import get_jwt_payload +from prisma.enums import SubscriptionTier + +from .v1 import _validate_checkout_redirect_url, v1_router + +TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a" +TEST_FRONTEND_ORIGIN = "https://app.example.com" + + +@pytest.fixture() +def client() -> fastapi.testclient.TestClient: + """Fresh FastAPI app + client per test with auth override applied. + + Using a fixture avoids the leaky global-app + try/finally teardown pattern: + if a test body raises before teardown_auth runs, dependency overrides were + previously leaking into subsequent tests. + """ + app = fastapi.FastAPI() + app.include_router(v1_router) + + def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]: + return {"sub": TEST_USER_ID, "role": "user", "email": "test@example.com"} + + app.dependency_overrides[get_jwt_payload] = override_get_jwt_payload + try: + yield fastapi.testclient.TestClient(app) + finally: + app.dependency_overrides.clear() + + +@pytest.fixture(autouse=True) +def _configure_frontend_origin(mocker: pytest_mock.MockFixture) -> None: + """Pin the configured frontend origin used by the open-redirect guard.""" + from backend.api.features import v1 as v1_mod + + mocker.patch.object( + v1_mod.settings.config, "frontend_base_url", TEST_FRONTEND_ORIGIN + ) + + +@pytest.fixture(autouse=True) +def _stub_pending_subscription_change(mocker: pytest_mock.MockFixture) -> None: + """Default pending-change lookup to None so tests don't hit Stripe/DB. + + Individual tests can override via their own mocker.patch call. + """ + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=None, + ) + + +@pytest.fixture(autouse=True) +def _stub_subscription_status_lookups(mocker: pytest_mock.MockFixture) -> None: + """Stub Stripe price + proration lookups used by get_subscription_status. + + The POST /credits/subscription handler now returns the full subscription + status payload from every branch (same-tier, FREE downgrade, paid→paid + modify, checkout creation), so every POST test implicitly hits these + helpers. Individual tests can override via their own mocker.patch call. + """ + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + new_callable=AsyncMock, + return_value=None, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + + +@pytest.mark.parametrize( + "url,expected", + [ + # Valid URLs matching the configured frontend origin + (f"{TEST_FRONTEND_ORIGIN}/success", True), + (f"{TEST_FRONTEND_ORIGIN}/cancel?ref=abc", True), + # Wrong origin + ("https://evil.example.org/phish", False), + ("https://evil.example.org", False), + # @ in URL (user:pass@host attack) + (f"https://attacker.example.com@{TEST_FRONTEND_ORIGIN}/ok", False), + # Backslash normalisation attack + (f"https:{TEST_FRONTEND_ORIGIN}\\@attacker.example.com/ok", False), + # javascript: scheme + ("javascript:alert(1)", False), + # Empty string + ("", False), + # Control character (U+0000) in URL + (f"{TEST_FRONTEND_ORIGIN}/ok\x00evil", False), + # Non-http scheme + (f"ftp://{TEST_FRONTEND_ORIGIN}/ok", False), + ], +) +def test_validate_checkout_redirect_url( + url: str, + expected: bool, + mocker: pytest_mock.MockFixture, +) -> None: + """_validate_checkout_redirect_url rejects adversarial inputs.""" + from backend.api.features import v1 as v1_mod + + mocker.patch.object( + v1_mod.settings.config, "frontend_base_url", TEST_FRONTEND_ORIGIN + ) + assert _validate_checkout_redirect_url(url) is expected + + +def test_get_subscription_status_pro( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """GET /credits/subscription returns PRO tier with Stripe price for a PRO user.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return "price_pro" if tier == SubscriptionTier.PRO else None + + async def mock_stripe_price_amount(price_id: str) -> int: + return 1999 if price_id == "price_pro" else 0 + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=mock_price_id, + ) + mocker.patch( + "backend.api.features.v1._get_stripe_price_amount", + side_effect=mock_stripe_price_amount, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=500, + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "PRO" + assert data["monthly_cost"] == 1999 + assert data["tier_costs"]["PRO"] == 1999 + assert data["tier_costs"]["BUSINESS"] == 0 + assert data["tier_costs"]["FREE"] == 0 + assert data["proration_credit_cents"] == 500 + + +def test_get_subscription_status_defaults_to_free( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """GET /credits/subscription when subscription_tier is None defaults to FREE.""" + mock_user = Mock() + mock_user.subscription_tier = None + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + new_callable=AsyncMock, + return_value=None, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == SubscriptionTier.FREE.value + assert data["monthly_cost"] == 0 + assert data["tier_costs"] == { + "FREE": 0, + "PRO": 0, + "BUSINESS": 0, + "ENTERPRISE": 0, + } + assert data["proration_credit_cents"] == 0 + + +def test_get_subscription_status_stripe_error_falls_back_to_zero( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """GET /credits/subscription returns cost=0 when Stripe price fetch fails (returns None). + + _get_stripe_price_amount returns None on StripeError so the error state is + not cached. The endpoint must treat None as 0 — not raise or return invalid data. + """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return "price_pro" if tier == SubscriptionTier.PRO else None + + async def mock_stripe_price_amount_none(price_id: str) -> None: + return None + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=mock_price_id, + ) + mocker.patch( + "backend.api.features.v1._get_stripe_price_amount", + side_effect=mock_stripe_price_amount_none, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "PRO" + # When Stripe returns None, cost falls back to 0 + assert data["monthly_cost"] == 0 + assert data["tier_costs"]["PRO"] == 0 + + +def test_update_subscription_tier_free_no_payment( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription to FREE tier when payment disabled skips Stripe.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + async def mock_feature_disabled(*args, **kwargs): + return False + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_disabled, + ) + mocker.patch( + "backend.api.features.v1.set_subscription_tier", + new_callable=AsyncMock, + ) + + response = client.post("/credits/subscription", json={"tier": "FREE"}) + + assert response.status_code == 200 + assert response.json()["url"] == "" + + +def test_update_subscription_tier_paid_beta_user( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription for paid tier when payment disabled returns 422.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.FREE + + async def mock_feature_disabled(*args, **kwargs): + return False + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_disabled, + ) + + response = client.post("/credits/subscription", json={"tier": "PRO"}) + + assert response.status_code == 422 + assert "not available" in response.json()["detail"] + + +def test_update_subscription_tier_paid_requires_urls( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription for paid tier without success/cancel URLs returns 422.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.FREE + + async def mock_feature_enabled(*args, **kwargs): + return True + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_enabled, + ) + + response = client.post("/credits/subscription", json={"tier": "PRO"}) + + assert response.status_code == 422 + + +def test_update_subscription_tier_creates_checkout( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription creates Stripe Checkout Session for paid upgrade.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.FREE + + async def mock_feature_enabled(*args, **kwargs): + return True + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_enabled, + ) + mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + return_value="https://checkout.stripe.com/pay/cs_test_abc", + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert response.json()["url"] == "https://checkout.stripe.com/pay/cs_test_abc" + + +def test_update_subscription_tier_rejects_open_redirect( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription rejects success/cancel URLs outside the frontend origin.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.FREE + + async def mock_feature_enabled(*args, **kwargs): + return True + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_enabled, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": "https://evil.example.org/phish", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 422 + checkout_mock.assert_not_awaited() + + +def test_update_subscription_tier_enterprise_blocked( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """ENTERPRISE users cannot self-service change tiers — must get 403.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.ENTERPRISE + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + set_tier_mock = mocker.patch( + "backend.api.features.v1.set_subscription_tier", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 403 + set_tier_mock.assert_not_awaited() + + +def test_update_subscription_tier_same_tier_releases_pending_change( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription for the user's current tier releases any pending change. + + "Stay on my current tier" — the collapsed replacement for the old + /credits/subscription/cancel-pending route. Always calls + release_pending_subscription_schedule (idempotent when nothing is pending) + and returns the refreshed status with url="". Never creates a Checkout + Session — that would double-charge a user who double-clicks their own tier. + """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + release_mock = mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + feature_mock = mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "BUSINESS" + assert data["url"] == "" + release_mock.assert_awaited_once_with(TEST_USER_ID) + checkout_mock.assert_not_awaited() + # Same-tier branch short-circuits before the payment-flag check. + feature_mock.assert_not_awaited() + + +def test_update_subscription_tier_same_tier_no_pending_change_returns_status( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Same-tier request when nothing is pending still returns status with url="".""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + release_mock = mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + new_callable=AsyncMock, + return_value=False, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "PRO" + assert data["url"] == "" + assert data["pending_tier"] is None + release_mock.assert_awaited_once_with(TEST_USER_ID) + checkout_mock.assert_not_awaited() + + +def test_update_subscription_tier_same_tier_stripe_error_returns_502( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Same-tier request surfaces a 502 when Stripe release fails. + + Carries forward the error contract from the removed + /credits/subscription/cancel-pending route so clients keep seeing 502 for + transient Stripe failures. + """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + side_effect=stripe.StripeError("network"), + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 502 + assert "contact support" in response.json()["detail"].lower() + + +def test_update_subscription_tier_free_with_payment_schedules_cancel_and_does_not_update_db( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Downgrading to FREE schedules Stripe cancellation at period end. + + The DB tier must NOT be updated immediately — the customer.subscription.deleted + webhook fires at period end and downgrades to FREE then. + """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + async def mock_feature_enabled(*args, **kwargs): + return True + + mock_cancel = mocker.patch( + "backend.api.features.v1.cancel_stripe_subscription", + new_callable=AsyncMock, + ) + mock_set_tier = mocker.patch( + "backend.api.features.v1.set_subscription_tier", + new_callable=AsyncMock, + ) + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_enabled, + ) + + response = client.post("/credits/subscription", json={"tier": "FREE"}) + + assert response.status_code == 200 + mock_cancel.assert_awaited_once() + mock_set_tier.assert_not_awaited() + + +def test_update_subscription_tier_free_cancel_failure_returns_502( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Downgrading to FREE returns 502 with a generic error (no Stripe detail leakage).""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + async def mock_feature_enabled(*args, **kwargs): + return True + + mocker.patch( + "backend.api.features.v1.cancel_stripe_subscription", + side_effect=stripe.StripeError( + "You did not provide an API key — internal detail that must not leak" + ), + ) + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_enabled, + ) + + response = client.post("/credits/subscription", json={"tier": "FREE"}) + + assert response.status_code == 502 + detail = response.json()["detail"] + # The raw Stripe error message must not appear in the client-facing detail. + assert "API key" not in detail + assert "contact support" in detail.lower() + + +def test_stripe_webhook_unconfigured_secret_returns_503( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Stripe webhook endpoint returns 503 when STRIPE_WEBHOOK_SECRET is not set. + + An empty webhook secret allows HMAC forgery: an attacker can compute a valid + HMAC signature over the same empty key. The handler must reject all requests + when the secret is unconfigured rather than proceeding with signature verification. + """ + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="", + ) + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=fake"}, + ) + assert response.status_code == 503 + + +def test_stripe_webhook_dispatches_subscription_events( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/stripe_webhook routes customer.subscription.created to sync handler.""" + stripe_sub_obj = { + "id": "sub_test", + "customer": "cus_test", + "status": "active", + "items": {"data": [{"price": {"id": "price_pro"}}]}, + } + event = { + "type": "customer.subscription.created", + "data": {"object": stripe_sub_obj}, + } + + # Ensure the webhook secret guard passes (non-empty secret required). + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + sync_mock = mocker.patch( + "backend.api.features.v1.sync_subscription_from_stripe", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + sync_mock.assert_awaited_once_with(stripe_sub_obj) + + +def test_stripe_webhook_dispatches_invoice_payment_failed( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/stripe_webhook routes invoice.payment_failed to the failure handler.""" + invoice_obj = { + "customer": "cus_test", + "subscription": "sub_test", + "amount_due": 1999, + } + event = { + "type": "invoice.payment_failed", + "data": {"object": invoice_obj}, + } + + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + failure_mock = mocker.patch( + "backend.api.features.v1.handle_subscription_payment_failure", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + failure_mock.assert_awaited_once_with(invoice_obj) + + +def test_update_subscription_tier_paid_to_paid_modifies_subscription( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription modifies existing subscription for paid→paid changes.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + modify_mock = mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert response.json()["url"] == "" + modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS) + checkout_mock.assert_not_awaited() + + +def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Admin-granted paid tier users are NOT sent to Stripe checkout for paid→paid changes. + + When modify_stripe_subscription_for_tier returns False (no Stripe subscription + found — admin-granted tier), the endpoint must update the DB tier directly and + return 200 with url="", rather than falling through to Checkout Session creation. + """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + # Return False = no Stripe subscription (admin-granted tier) + modify_mock = mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + return_value=False, + ) + set_tier_mock = mocker.patch( + "backend.api.features.v1.set_subscription_tier", + new_callable=AsyncMock, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert response.json()["url"] == "" + modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS) + # DB tier updated directly — no Stripe Checkout Session created + set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS) + checkout_mock.assert_not_awaited() + + +def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription returns 502 when Stripe modification fails.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + side_effect=stripe.StripeError("connection error"), + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 502 + + +def test_update_subscription_tier_free_no_stripe_subscription( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Downgrading to FREE when no Stripe subscription exists updates DB tier directly. + + Admin-granted paid tiers have no associated Stripe subscription. When such a + user requests a self-service downgrade, cancel_stripe_subscription returns False + (nothing to cancel), so the endpoint must immediately call set_subscription_tier + rather than waiting for a webhook that will never arrive. + """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + # Simulate no active Stripe subscriptions — returns False + cancel_mock = mocker.patch( + "backend.api.features.v1.cancel_stripe_subscription", + new_callable=AsyncMock, + return_value=False, + ) + set_tier_mock = mocker.patch( + "backend.api.features.v1.set_subscription_tier", + new_callable=AsyncMock, + ) + + response = client.post("/credits/subscription", json={"tier": "FREE"}) + + assert response.status_code == 200 + assert response.json()["url"] == "" + cancel_mock.assert_awaited_once_with(TEST_USER_ID) + # DB tier must be updated immediately — no webhook will fire for a missing sub + set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.FREE) + + +def test_get_subscription_status_includes_pending_tier( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """GET /credits/subscription exposes pending_tier and pending_tier_effective_at.""" + import datetime as dt + + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + effective_at = dt.datetime(2030, 1, 1, tzinfo=dt.timezone.utc) + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return None + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=mock_price_id, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=(SubscriptionTier.PRO, effective_at), + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["pending_tier"] == "PRO" + assert data["pending_tier_effective_at"] is not None + + +def test_get_subscription_status_no_pending_tier( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """When no pending change exists the response omits pending_tier.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + new_callable=AsyncMock, + return_value=None, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["pending_tier"] is None + assert data["pending_tier_effective_at"] is None + + +def test_update_subscription_tier_downgrade_paid_to_paid_schedules( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """A BUSINESS→PRO downgrade request dispatches to modify_stripe_subscription_for_tier.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + modify_mock = mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert response.json()["url"] == "" + modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.PRO) + checkout_mock.assert_not_awaited() + + +def test_stripe_webhook_dispatches_subscription_schedule_released( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """subscription_schedule.released routes to sync_subscription_schedule_from_stripe.""" + schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"} + event = { + "type": "subscription_schedule.released", + "data": {"object": schedule_obj}, + } + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + sync_mock = mocker.patch( + "backend.api.features.v1.sync_subscription_schedule_from_stripe", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + sync_mock.assert_awaited_once_with(schedule_obj) + + +def test_stripe_webhook_ignores_subscription_schedule_updated( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """subscription_schedule.updated must NOT dispatch: our own + SubscriptionSchedule.create/.modify calls fire this event and would + otherwise loop redundant traffic through the sync handler. State + transitions we care about surface via .released/.completed, and phase + advance to a new price is already covered by customer.subscription.updated. + """ + schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"} + event = { + "type": "subscription_schedule.updated", + "data": {"object": schedule_obj}, + } + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + sync_mock = mocker.patch( + "backend.api.features.v1.sync_subscription_schedule_from_stripe", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + sync_mock.assert_not_awaited() diff --git a/autogpt_platform/backend/backend/api/features/v1.py b/autogpt_platform/backend/backend/api/features/v1.py index d208114f95..12a31e6bd1 100644 --- a/autogpt_platform/backend/backend/api/features/v1.py +++ b/autogpt_platform/backend/backend/api/features/v1.py @@ -5,7 +5,8 @@ import time import uuid from collections import defaultdict from datetime import datetime, timezone -from typing import Annotated, Any, Sequence, get_args +from typing import Annotated, Any, Literal, Sequence, cast, get_args +from urllib.parse import urlparse import pydantic import stripe @@ -24,10 +25,12 @@ from fastapi import ( UploadFile, ) from fastapi.concurrency import run_in_threadpool -from pydantic import BaseModel +from prisma.enums import SubscriptionTier +from pydantic import BaseModel, Field from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND from typing_extensions import Optional, TypedDict +from backend.api.features.workspace.routes import create_file_download_response from backend.api.model import ( CreateAPIKeyRequest, CreateAPIKeyResponse, @@ -47,12 +50,24 @@ from backend.data.auth import api_key as api_key_db from backend.data.block import BlockInput, CompletedBlockOutput from backend.data.credit import ( AutoTopUpConfig, + PendingChangeUnknown, RefundRequest, TransactionHistory, UserCredit, + cancel_stripe_subscription, + create_subscription_checkout, get_auto_top_up, + get_pending_subscription_change, + get_proration_credit_cents, + get_subscription_price_id, get_user_credit_model, + handle_subscription_payment_failure, + modify_stripe_subscription_for_tier, + release_pending_subscription_schedule, set_auto_top_up, + set_subscription_tier, + sync_subscription_from_stripe, + sync_subscription_schedule_from_stripe, ) from backend.data.graph import GraphSettings from backend.data.model import CredentialsMetaInput, UserOnboarding @@ -82,6 +97,7 @@ from backend.data.user import ( update_user_notification_preference, update_user_timezone, ) +from backend.data.workspace import get_workspace_file_by_id from backend.executor import scheduler from backend.executor import utils as execution_utils from backend.integrations.webhooks.graph_lifecycle_hooks import ( @@ -661,9 +677,12 @@ async def configure_user_auto_top_up( raise HTTPException(status_code=422, detail=str(e)) raise - await set_auto_top_up( - user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount) - ) + try: + await set_auto_top_up( + user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount) + ) + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) return "Auto top-up settings updated" @@ -679,41 +698,430 @@ async def get_user_auto_top_up( return await get_auto_top_up(user_id) +class SubscriptionTierRequest(BaseModel): + tier: Literal["FREE", "PRO", "BUSINESS"] + success_url: str = "" + cancel_url: str = "" + + +class SubscriptionStatusResponse(BaseModel): + tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"] + monthly_cost: int # amount in cents (Stripe convention) + tier_costs: dict[str, int] # tier name -> amount in cents + proration_credit_cents: int # unused portion of current sub to convert on upgrade + pending_tier: Optional[Literal["FREE", "PRO", "BUSINESS"]] = None + pending_tier_effective_at: Optional[datetime] = None + url: str = Field( + default="", + description=( + "Populated only when POST /credits/subscription starts a Stripe Checkout" + " Session (FREE → paid upgrade). Empty string in all other branches —" + " the client redirects to this URL when non-empty." + ), + ) + + +def _validate_checkout_redirect_url(url: str) -> bool: + """Return True if `url` matches the configured frontend origin. + + Prevents open-redirect: attackers must not be able to supply arbitrary + success_url/cancel_url that Stripe will redirect users to after checkout. + + Pre-parse rejection rules (applied before urlparse): + - Backslashes (``\\``) are normalised differently across parsers/browsers. + - Control characters (U+0000–U+001F) are not valid in URLs and may confuse + some URL-parsing implementations. + """ + # Reject characters that can confuse URL parsers before any parsing. + if "\\" in url: + return False + if any(ord(c) < 0x20 for c in url): + return False + + allowed = settings.config.frontend_base_url or settings.config.platform_base_url + if not allowed: + # No configured origin — refuse to validate rather than allow arbitrary URLs. + return False + try: + parsed = urlparse(url) + allowed_parsed = urlparse(allowed) + except ValueError: + return False + if parsed.scheme not in ("http", "https"): + return False + # Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component + # can trick browsers into connecting to a different host than displayed. + # ``@`` in query/fragment is harmless and must be allowed. + if "@" in parsed.netloc: + return False + return ( + parsed.scheme == allowed_parsed.scheme + and parsed.netloc == allowed_parsed.netloc + ) + + +@cached(ttl_seconds=300, maxsize=32, cache_none=False) +async def _get_stripe_price_amount(price_id: str) -> int | None: + """Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes. + + Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out + of caching the ``None`` sentinel so the next request retries Stripe instead + of being served a stale "no price" for the rest of the TTL window. Callers + should treat ``None`` as an unknown price and fall back to 0. + + Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on + every GET /credits/subscription page load and reduces quota consumption. + """ + try: + price = await run_in_threadpool(stripe.Price.retrieve, price_id) + return price.unit_amount or 0 + except stripe.StripeError: + logger.warning( + "Failed to retrieve Stripe price %s — returning None (not cached)", + price_id, + ) + return None + + +@v1_router.get( + path="/credits/subscription", + summary="Get subscription tier, current cost, and all tier costs", + operation_id="getSubscriptionStatus", + tags=["credits"], + dependencies=[Security(requires_user)], +) +async def get_subscription_status( + user_id: Annotated[str, Security(get_user_id)], +) -> SubscriptionStatusResponse: + user = await get_user_by_id(user_id) + tier = user.subscription_tier or SubscriptionTier.FREE + + paid_tiers = [SubscriptionTier.PRO, SubscriptionTier.BUSINESS] + price_ids = await asyncio.gather( + *[get_subscription_price_id(t) for t in paid_tiers] + ) + + tier_costs: dict[str, int] = { + SubscriptionTier.FREE.value: 0, + SubscriptionTier.ENTERPRISE.value: 0, + } + + async def _cost(pid: str | None) -> int: + return (await _get_stripe_price_amount(pid) or 0) if pid else 0 + + costs = await asyncio.gather(*[_cost(pid) for pid in price_ids]) + for t, cost in zip(paid_tiers, costs): + tier_costs[t.value] = cost + + current_monthly_cost = tier_costs.get(tier.value, 0) + proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost) + + try: + pending = await get_pending_subscription_change(user_id) + except (stripe.StripeError, PendingChangeUnknown): + # Swallow Stripe-side failures (rate limits, transient network) AND + # PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both + # propagate past the cache so the next request retries fresh instead + # of serving a stale None for the TTL window. Let real bugs (KeyError, + # AttributeError, etc.) propagate so they surface in Sentry. + logger.exception( + "get_subscription_status: failed to resolve pending change for user %s", + user_id, + ) + pending = None + + response = SubscriptionStatusResponse( + tier=tier.value, + monthly_cost=current_monthly_cost, + tier_costs=tier_costs, + proration_credit_cents=proration_credit, + ) + if pending is not None: + pending_tier_enum, pending_effective_at = pending + if pending_tier_enum == SubscriptionTier.FREE: + response.pending_tier = "FREE" + elif pending_tier_enum == SubscriptionTier.PRO: + response.pending_tier = "PRO" + elif pending_tier_enum == SubscriptionTier.BUSINESS: + response.pending_tier = "BUSINESS" + if response.pending_tier is not None: + response.pending_tier_effective_at = pending_effective_at + return response + + +@v1_router.post( + path="/credits/subscription", + summary="Update subscription tier or start a Stripe Checkout session", + operation_id="updateSubscriptionTier", + tags=["credits"], + dependencies=[Security(requires_user)], +) +async def update_subscription_tier( + request: SubscriptionTierRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> SubscriptionStatusResponse: + # Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type. + tier = SubscriptionTier(request.tier) + + # ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users. + user = await get_user_by_id(user_id) + if (user.subscription_tier or SubscriptionTier.FREE) == SubscriptionTier.ENTERPRISE: + raise HTTPException( + status_code=403, + detail="ENTERPRISE subscription changes must be managed by an administrator", + ) + + # Same-tier request = "stay on my current tier" = cancel any pending + # scheduled change (paid→paid downgrade or paid→FREE cancel). This is the + # collapsed behaviour that replaces the old /credits/subscription/cancel-pending + # route. Safe when no pending change exists: release_pending_subscription_schedule + # returns False and we simply return the current status. + if (user.subscription_tier or SubscriptionTier.FREE) == tier: + try: + await release_pending_subscription_schedule(user_id) + except stripe.StripeError as e: + logger.exception( + "Stripe error releasing pending subscription change for user %s: %s", + user_id, + e, + ) + raise HTTPException( + status_code=502, + detail=( + "Unable to cancel the pending subscription change right now. " + "Please try again or contact support." + ), + ) + return await get_subscription_status(user_id) + + payment_enabled = await is_feature_enabled( + Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False + ) + + # Downgrade to FREE: schedule Stripe cancellation at period end so the user + # keeps their tier for the time they already paid for. The DB tier is NOT + # updated here when a subscription exists — the customer.subscription.deleted + # webhook fires at period end and downgrades to FREE then. + # Exception: if the user has no active Stripe subscription (e.g. admin-granted + # tier), cancel_stripe_subscription returns False and we update the DB tier + # immediately since no webhook will ever fire. + # When payment is disabled entirely, update the DB tier directly. + if tier == SubscriptionTier.FREE: + if payment_enabled: + try: + had_subscription = await cancel_stripe_subscription(user_id) + except stripe.StripeError as e: + # Log full Stripe error server-side but return a generic message + # to the client — raw Stripe errors can leak customer/sub IDs and + # infrastructure config details. + logger.exception( + "Stripe error cancelling subscription for user %s: %s", + user_id, + e, + ) + raise HTTPException( + status_code=502, + detail=( + "Unable to cancel your subscription right now. " + "Please try again or contact support." + ), + ) + if not had_subscription: + # No active Stripe subscription found — the user was on an + # admin-granted tier. Update DB immediately since the + # subscription.deleted webhook will never fire. + await set_subscription_tier(user_id, tier) + return await get_subscription_status(user_id) + await set_subscription_tier(user_id, tier) + return await get_subscription_status(user_id) + + # Paid tier changes require payment to be enabled — block self-service upgrades + # when the flag is off. Admins use the /api/admin/ routes to set tiers directly. + if not payment_enabled: + raise HTTPException( + status_code=422, + detail=f"Subscription not available for tier {tier}", + ) + + # Paid→paid tier change: if the user already has a Stripe subscription, + # modify it in-place with proration instead of creating a new Checkout + # Session. This preserves remaining paid time and avoids double-charging. + # The customer.subscription.updated webhook fires and updates the DB tier. + current_tier = user.subscription_tier or SubscriptionTier.FREE + if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS): + try: + modified = await modify_stripe_subscription_for_tier(user_id, tier) + if modified: + return await get_subscription_status(user_id) + # modify_stripe_subscription_for_tier returns False when no active + # Stripe subscription exists — i.e. the user has an admin-granted + # paid tier with no Stripe record. In that case, update the DB + # tier directly (same as the FREE-downgrade path for admin-granted + # users) rather than sending them through a new Checkout Session. + await set_subscription_tier(user_id, tier) + return await get_subscription_status(user_id) + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) + except stripe.StripeError as e: + logger.exception( + "Stripe error modifying subscription for user %s: %s", user_id, e + ) + raise HTTPException( + status_code=502, + detail=( + "Unable to update your subscription right now. " + "Please try again or contact support." + ), + ) + + # Paid upgrade from FREE → create Stripe Checkout Session. + if not request.success_url or not request.cancel_url: + raise HTTPException( + status_code=422, + detail="success_url and cancel_url are required for paid tier upgrades", + ) + # Open-redirect protection: both URLs must point to the configured frontend + # origin, otherwise an attacker could use our Stripe integration as a + # redirector to arbitrary phishing sites. + # + # Fail early with a clear 503 if the server is misconfigured (neither + # frontend_base_url nor platform_base_url set), so operators get an + # actionable error instead of the misleading "must match the platform + # frontend origin" 422 that _validate_checkout_redirect_url would otherwise + # produce when `allowed` is empty. + if not (settings.config.frontend_base_url or settings.config.platform_base_url): + logger.error( + "update_subscription_tier: neither frontend_base_url nor " + "platform_base_url is configured; cannot validate checkout redirect URLs" + ) + raise HTTPException( + status_code=503, + detail=( + "Payment redirect URLs cannot be validated: " + "frontend_base_url or platform_base_url must be set on the server." + ), + ) + if not _validate_checkout_redirect_url( + request.success_url + ) or not _validate_checkout_redirect_url(request.cancel_url): + raise HTTPException( + status_code=422, + detail="success_url and cancel_url must match the platform frontend origin", + ) + try: + url = await create_subscription_checkout( + user_id=user_id, + tier=tier, + success_url=request.success_url, + cancel_url=request.cancel_url, + ) + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) + except stripe.StripeError as e: + logger.exception( + "Stripe error creating checkout session for user %s: %s", user_id, e + ) + raise HTTPException( + status_code=502, + detail=( + "Unable to start checkout right now. " + "Please try again or contact support." + ), + ) + + status = await get_subscription_status(user_id) + status.url = url + return status + + @v1_router.post( path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"] ) async def stripe_webhook(request: Request): + webhook_secret = settings.secrets.stripe_webhook_secret + if not webhook_secret: + # Guard: an empty secret allows HMAC forgery (attacker can compute a valid + # signature over the same empty key). Reject all webhook calls when unconfigured. + logger.error( + "stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — " + "rejecting request to prevent signature bypass" + ) + raise HTTPException(status_code=503, detail="Webhook not configured") + # Get the raw request body payload = await request.body() # Get the signature header sig_header = request.headers.get("stripe-signature") try: - event = stripe.Webhook.construct_event( - payload, sig_header, settings.secrets.stripe_webhook_secret - ) - except ValueError as e: + event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret) + except ValueError: # Invalid payload - raise HTTPException( - status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}" - ) - except stripe.SignatureVerificationError as e: + raise HTTPException(status_code=400, detail="Invalid payload") + except stripe.SignatureVerificationError: # Invalid signature - raise HTTPException( - status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}" + raise HTTPException(status_code=400, detail="Invalid signature") + + # Defensive payload extraction. A malformed payload (missing/non-dict + # `data.object`, missing `id`) would otherwise raise KeyError/TypeError + # AFTER signature verification — which Stripe interprets as a delivery + # failure and retries forever, while spamming Sentry with no useful info. + # Acknowledge with 200 and a warning so Stripe stops retrying. + event_type = event.get("type", "") + event_data = event.get("data") or {} + data_object = event_data.get("object") if isinstance(event_data, dict) else None + if not isinstance(data_object, dict): + logger.warning( + "stripe_webhook: %s missing or non-dict data.object; ignoring", + event_type, ) + return Response(status_code=200) - if ( - event["type"] == "checkout.session.completed" - or event["type"] == "checkout.session.async_payment_succeeded" + if event_type in ( + "checkout.session.completed", + "checkout.session.async_payment_succeeded", ): - await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"]) + session_id = data_object.get("id") + if not session_id: + logger.warning( + "stripe_webhook: %s missing data.object.id; ignoring", event_type + ) + return Response(status_code=200) + await UserCredit().fulfill_checkout(session_id=session_id) - if event["type"] == "charge.dispute.created": - await UserCredit().handle_dispute(event["data"]["object"]) + if event_type in ( + "customer.subscription.created", + "customer.subscription.updated", + "customer.subscription.deleted", + ): + await sync_subscription_from_stripe(data_object) - if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed": - await UserCredit().deduct_credits(event["data"]["object"]) + # `subscription_schedule.updated` is deliberately omitted: our own + # `SubscriptionSchedule.create` + `.modify` calls in + # `_schedule_downgrade_at_period_end` would fire that event right back at us + # and loop redundant traffic through this handler. We only care about state + # transitions (released / completed); phase advance to the new price is + # already covered by `customer.subscription.updated`. + if event_type in ( + "subscription_schedule.released", + "subscription_schedule.completed", + ): + await sync_subscription_schedule_from_stripe(data_object) + + if event_type == "invoice.payment_failed": + await handle_subscription_payment_failure(data_object) + + # `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects + # (Dispute/Refund). The Stripe webhook payload's `data.object` is a + # StripeObject (a dict subclass) carrying that runtime shape, so we cast + # to satisfy the type checker without changing runtime behaviour. + if event_type == "charge.dispute.created": + await UserCredit().handle_dispute(cast(stripe.Dispute, data_object)) + + if event_type == "refund.created" or event_type == "charge.dispute.closed": + await UserCredit().deduct_credits( + cast("stripe.Refund | stripe.Dispute", data_object) + ) return Response(status_code=200) @@ -1297,6 +1705,10 @@ async def enable_execution_sharing( # Generate a unique share token share_token = str(uuid.uuid4()) + # Remove stale allowlist records before updating the token — prevents a + # window where old records + new token could coexist. + await execution_db.delete_shared_execution_files(execution_id=graph_exec_id) + # Update the execution with share info await execution_db.update_graph_execution_share_status( execution_id=graph_exec_id, @@ -1306,6 +1718,14 @@ async def enable_execution_sharing( shared_at=datetime.now(timezone.utc), ) + # Create allowlist of workspace files referenced in outputs + await execution_db.create_shared_execution_files( + execution_id=graph_exec_id, + share_token=share_token, + user_id=user_id, + outputs=execution.outputs, + ) + # Return the share URL frontend_url = settings.config.frontend_base_url or "http://localhost:3000" share_url = f"{frontend_url}/share/{share_token}" @@ -1331,6 +1751,9 @@ async def disable_execution_sharing( if not execution: raise HTTPException(status_code=404, detail="Execution not found") + # Remove shared file allowlist records + await execution_db.delete_shared_execution_files(execution_id=graph_exec_id) + # Remove share info await execution_db.update_graph_execution_share_status( execution_id=graph_exec_id, @@ -1356,6 +1779,43 @@ async def get_shared_execution( return execution +@v1_router.get( + "/public/shared/{share_token}/files/{file_id}/download", + summary="Download a file from a shared execution", + operation_id="download_shared_file", + tags=["graphs"], +) +async def download_shared_file( + share_token: Annotated[ + str, + Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"), + ], + file_id: Annotated[ + str, + Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"), + ], +) -> Response: + """Download a workspace file from a shared execution (no auth required). + + Validates that the file was explicitly exposed when sharing was enabled. + Returns a uniform 404 for all failure modes to prevent enumeration attacks. + """ + # Single-query validation against the allowlist + execution_id = await execution_db.get_shared_execution_file( + share_token=share_token, file_id=file_id + ) + if not execution_id: + raise HTTPException(status_code=404, detail="Not found") + + # Look up the actual file (no workspace scoping needed — the allowlist + # already validated that this file belongs to the shared execution) + file = await get_workspace_file_by_id(file_id) + if not file: + raise HTTPException(status_code=404, detail="Not found") + + return await create_file_download_response(file, inline=True) + + ######################################################## ##################### Schedules ######################## ######################################################## diff --git a/autogpt_platform/backend/backend/api/features/v1_share_test.py b/autogpt_platform/backend/backend/api/features/v1_share_test.py new file mode 100644 index 0000000000..de5d14ad80 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/v1_share_test.py @@ -0,0 +1,157 @@ +"""Tests for the public shared file download endpoint.""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from starlette.responses import Response + +from backend.api.features.v1 import v1_router +from backend.data.workspace import WorkspaceFile + +app = FastAPI() +app.include_router(v1_router, prefix="/api") + +VALID_TOKEN = "550e8400-e29b-41d4-a716-446655440000" +VALID_FILE_ID = "6ba7b810-9dad-11d1-80b4-00c04fd430c8" + + +def _make_workspace_file(**overrides) -> WorkspaceFile: + defaults = { + "id": VALID_FILE_ID, + "workspace_id": "ws-001", + "created_at": datetime(2026, 1, 1, tzinfo=timezone.utc), + "updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc), + "name": "image.png", + "path": "/image.png", + "storage_path": "local://uploads/image.png", + "mime_type": "image/png", + "size_bytes": 4, + "checksum": None, + "is_deleted": False, + "deleted_at": None, + "metadata": {}, + } + defaults.update(overrides) + return WorkspaceFile(**defaults) + + +def _mock_download_response(**kwargs): + """Return an AsyncMock that resolves to a Response with inline disposition.""" + + async def _handler(file, *, inline=False): + return Response( + content=b"\x89PNG", + media_type="image/png", + headers={ + "Content-Disposition": ( + 'inline; filename="image.png"' + if inline + else 'attachment; filename="image.png"' + ), + "Content-Length": "4", + }, + ) + + return _handler + + +class TestDownloadSharedFile: + """Tests for GET /api/public/shared/{token}/files/{id}/download.""" + + @pytest.fixture(autouse=True) + def _client(self): + self.client = TestClient(app, raise_server_exceptions=False) + + def test_valid_token_and_file_returns_inline_content(self): + with ( + patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value="exec-123", + ), + patch( + "backend.api.features.v1.get_workspace_file_by_id", + new_callable=AsyncMock, + return_value=_make_workspace_file(), + ), + patch( + "backend.api.features.v1.create_file_download_response", + side_effect=_mock_download_response(), + ), + ): + response = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + + assert response.status_code == 200 + assert response.content == b"\x89PNG" + assert "inline" in response.headers["Content-Disposition"] + + def test_invalid_token_format_returns_422(self): + response = self.client.get( + f"/api/public/shared/not-a-uuid/files/{VALID_FILE_ID}/download" + ) + assert response.status_code == 422 + + def test_token_not_in_allowlist_returns_404(self): + with patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value=None, + ): + response = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + assert response.status_code == 404 + + def test_file_missing_from_workspace_returns_404(self): + with ( + patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value="exec-123", + ), + patch( + "backend.api.features.v1.get_workspace_file_by_id", + new_callable=AsyncMock, + return_value=None, + ), + ): + response = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + assert response.status_code == 404 + + def test_uniform_404_prevents_enumeration(self): + """Both failure modes produce identical 404 — no information leak.""" + with patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value=None, + ): + resp_no_allow = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + + with ( + patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value="exec-123", + ), + patch( + "backend.api.features.v1.get_workspace_file_by_id", + new_callable=AsyncMock, + return_value=None, + ), + ): + resp_no_file = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + + assert resp_no_allow.status_code == 404 + assert resp_no_file.status_code == 404 + assert resp_no_allow.json() == resp_no_file.json() diff --git a/autogpt_platform/backend/backend/api/features/workspace/routes.py b/autogpt_platform/backend/backend/api/features/workspace/routes.py index 8ca339edbd..c22cc445c4 100644 --- a/autogpt_platform/backend/backend/api/features/workspace/routes.py +++ b/autogpt_platform/backend/backend/api/features/workspace/routes.py @@ -12,7 +12,7 @@ import fastapi from autogpt_libs.auth.dependencies import get_user_id, requires_user from fastapi import Query, UploadFile from fastapi.responses import Response -from pydantic import BaseModel +from pydantic import BaseModel, Field from backend.data.workspace import ( WorkspaceFile, @@ -29,7 +29,9 @@ from backend.util.workspace import WorkspaceManager from backend.util.workspace_storage import get_workspace_storage -def _sanitize_filename_for_header(filename: str) -> str: +def _sanitize_filename_for_header( + filename: str, disposition: str = "attachment" +) -> str: """ Sanitize filename for Content-Disposition header to prevent header injection. @@ -44,11 +46,11 @@ def _sanitize_filename_for_header(filename: str) -> str: # Check if filename has non-ASCII characters try: sanitized.encode("ascii") - return f'attachment; filename="{sanitized}"' + return f'{disposition}; filename="{sanitized}"' except UnicodeEncodeError: # Use RFC5987 encoding for UTF-8 filenames encoded = quote(sanitized, safe="") - return f"attachment; filename*=UTF-8''{encoded}" + return f"{disposition}; filename*=UTF-8''{encoded}" logger = logging.getLogger(__name__) @@ -58,19 +60,26 @@ router = fastapi.APIRouter( ) -def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response: +def _create_streaming_response( + content: bytes, file: WorkspaceFile, *, inline: bool = False +) -> Response: """Create a streaming response for file content.""" + disposition = _sanitize_filename_for_header( + file.name, disposition="inline" if inline else "attachment" + ) return Response( content=content, media_type=file.mime_type, headers={ - "Content-Disposition": _sanitize_filename_for_header(file.name), + "Content-Disposition": disposition, "Content-Length": str(len(content)), }, ) -async def _create_file_download_response(file: WorkspaceFile) -> Response: +async def create_file_download_response( + file: WorkspaceFile, *, inline: bool = False +) -> Response: """ Create a download response for a workspace file. @@ -82,7 +91,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response: # For local storage, stream the file directly if file.storage_path.startswith("local://"): content = await storage.retrieve(file.storage_path) - return _create_streaming_response(content, file) + return _create_streaming_response(content, file, inline=inline) # For GCS, try to redirect to signed URL, fall back to streaming try: @@ -90,7 +99,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response: # If we got back an API path (fallback), stream directly instead if url.startswith("/api/"): content = await storage.retrieve(file.storage_path) - return _create_streaming_response(content, file) + return _create_streaming_response(content, file, inline=inline) return fastapi.responses.RedirectResponse(url=url, status_code=302) except Exception as e: # Log the signed URL failure with context @@ -102,7 +111,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response: # Fall back to streaming directly from GCS try: content = await storage.retrieve(file.storage_path) - return _create_streaming_response(content, file) + return _create_streaming_response(content, file, inline=inline) except Exception as fallback_error: logger.error( f"Fallback streaming also failed for file {file.id} " @@ -131,9 +140,26 @@ class StorageUsageResponse(BaseModel): file_count: int +class WorkspaceFileItem(BaseModel): + id: str + name: str + path: str + mime_type: str + size_bytes: int + metadata: dict = Field(default_factory=dict) + created_at: str + + +class ListFilesResponse(BaseModel): + files: list[WorkspaceFileItem] + offset: int = 0 + has_more: bool = False + + @router.get( "/files/{file_id}/download", summary="Download file by ID", + operation_id="getWorkspaceDownloadFileById", ) async def download_file( user_id: Annotated[str, fastapi.Security(get_user_id)], @@ -152,12 +178,13 @@ async def download_file( if file is None: raise fastapi.HTTPException(status_code=404, detail="File not found") - return await _create_file_download_response(file) + return await create_file_download_response(file) @router.delete( "/files/{file_id}", summary="Delete a workspace file", + operation_id="deleteWorkspaceFile", ) async def delete_workspace_file( user_id: Annotated[str, fastapi.Security(get_user_id)], @@ -183,6 +210,7 @@ async def delete_workspace_file( @router.post( "/files/upload", summary="Upload file to workspace", + operation_id="uploadWorkspaceFile", ) async def upload_file( user_id: Annotated[str, fastapi.Security(get_user_id)], @@ -196,6 +224,9 @@ async def upload_file( Files are stored in session-scoped paths when session_id is provided, so the agent's session-scoped tools can discover them automatically. """ + # Empty-string session_id drops session scoping; normalize to None. + session_id = session_id or None + config = Config() # Sanitize filename — strip any directory components @@ -250,16 +281,27 @@ async def upload_file( manager = WorkspaceManager(user_id, workspace.id, session_id) try: workspace_file = await manager.write_file( - content, filename, overwrite=overwrite + content, filename, overwrite=overwrite, metadata={"origin": "user-upload"} ) except ValueError as e: - raise fastapi.HTTPException(status_code=409, detail=str(e)) from e + # write_file raises ValueError for both path-conflict and size-limit + # cases; map each to its correct HTTP status. + message = str(e) + if message.startswith("File too large"): + raise fastapi.HTTPException(status_code=413, detail=message) from e + raise fastapi.HTTPException(status_code=409, detail=message) from e # Post-write storage check — eliminates TOCTOU race on the quota. # If a concurrent upload pushed us over the limit, undo this write. new_total = await get_workspace_total_size(workspace.id) if storage_limit_bytes and new_total > storage_limit_bytes: - await soft_delete_workspace_file(workspace_file.id, workspace.id) + try: + await soft_delete_workspace_file(workspace_file.id, workspace.id) + except Exception as e: + logger.warning( + f"Failed to soft-delete over-quota file {workspace_file.id} " + f"in workspace {workspace.id}: {e}" + ) raise fastapi.HTTPException( status_code=413, detail={ @@ -281,6 +323,7 @@ async def upload_file( @router.get( "/storage/usage", summary="Get workspace storage usage", + operation_id="getWorkspaceStorageUsage", ) async def get_storage_usage( user_id: Annotated[str, fastapi.Security(get_user_id)], @@ -301,3 +344,57 @@ async def get_storage_usage( used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0, file_count=file_count, ) + + +@router.get( + "/files", + summary="List workspace files", + operation_id="listWorkspaceFiles", +) +async def list_workspace_files( + user_id: Annotated[str, fastapi.Security(get_user_id)], + session_id: str | None = Query(default=None), + limit: int = Query(default=200, ge=1, le=1000), + offset: int = Query(default=0, ge=0), +) -> ListFilesResponse: + """ + List files in the user's workspace. + + When session_id is provided, only files for that session are returned. + Otherwise, all files across sessions are listed. Results are paginated + via `limit`/`offset`; `has_more` indicates whether additional pages exist. + """ + workspace = await get_or_create_workspace(user_id) + + # Treat empty-string session_id the same as omitted — an empty value + # would otherwise silently list files across every session instead of + # scoping to one. + session_id = session_id or None + + manager = WorkspaceManager(user_id, workspace.id, session_id) + include_all = session_id is None + # Fetch one extra to compute has_more without a separate count query. + files = await manager.list_files( + limit=limit + 1, + offset=offset, + include_all_sessions=include_all, + ) + has_more = len(files) > limit + page = files[:limit] + + return ListFilesResponse( + files=[ + WorkspaceFileItem( + id=f.id, + name=f.name, + path=f.path, + mime_type=f.mime_type, + size_bytes=f.size_bytes, + metadata=f.metadata or {}, + created_at=f.created_at.isoformat(), + ) + for f in page + ], + offset=offset, + has_more=has_more, + ) diff --git a/autogpt_platform/backend/backend/api/features/workspace/routes_test.py b/autogpt_platform/backend/backend/api/features/workspace/routes_test.py index 76da67aaa1..ffc712014f 100644 --- a/autogpt_platform/backend/backend/api/features/workspace/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/workspace/routes_test.py @@ -1,48 +1,28 @@ -"""Tests for workspace file upload and download routes.""" - import io from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch import fastapi import fastapi.testclient import pytest -import pytest_mock -from backend.api.features.workspace import routes as workspace_routes -from backend.data.workspace import WorkspaceFile +from backend.api.features.workspace.routes import router +from backend.data.workspace import Workspace, WorkspaceFile app = fastapi.FastAPI() -app.include_router(workspace_routes.router) +app.include_router(router) @app.exception_handler(ValueError) async def _value_error_handler( request: fastapi.Request, exc: ValueError ) -> fastapi.responses.JSONResponse: - """Mirror the production ValueError → 400 mapping from rest_api.py.""" + """Mirror the production ValueError → 400 mapping from the REST app.""" return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)}) client = fastapi.testclient.TestClient(app) -TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a" - -MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})() - -_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc) - -MOCK_FILE = WorkspaceFile( - id="file-aaa-bbb", - workspace_id="ws-1", - created_at=_NOW, - updated_at=_NOW, - name="hello.txt", - path="/session/hello.txt", - mime_type="text/plain", - size_bytes=13, - storage_path="local://hello.txt", -) - @pytest.fixture(autouse=True) def setup_app_auth(mock_jwt_user): @@ -53,25 +33,201 @@ def setup_app_auth(mock_jwt_user): app.dependency_overrides.clear() +def _make_workspace(user_id: str = "test-user-id") -> Workspace: + return Workspace( + id="ws-001", + user_id=user_id, + created_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + +def _make_file(**overrides) -> WorkspaceFile: + defaults = { + "id": "file-001", + "workspace_id": "ws-001", + "created_at": datetime(2026, 1, 1, tzinfo=timezone.utc), + "updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc), + "name": "test.txt", + "path": "/test.txt", + "storage_path": "local://test.txt", + "mime_type": "text/plain", + "size_bytes": 100, + "checksum": None, + "is_deleted": False, + "deleted_at": None, + "metadata": {}, + } + defaults.update(overrides) + return WorkspaceFile(**defaults) + + +def _make_file_mock(**overrides) -> MagicMock: + """Create a mock WorkspaceFile to simulate DB records with null fields.""" + defaults = { + "id": "file-001", + "name": "test.txt", + "path": "/test.txt", + "mime_type": "text/plain", + "size_bytes": 100, + "metadata": {}, + "created_at": datetime(2026, 1, 1, tzinfo=timezone.utc), + } + defaults.update(overrides) + mock = MagicMock(spec=WorkspaceFile) + for k, v in defaults.items(): + setattr(mock, k, v) + return mock + + +# -- list_workspace_files tests -- + + +@patch("backend.api.features.workspace.routes.get_or_create_workspace") +@patch("backend.api.features.workspace.routes.WorkspaceManager") +def test_list_files_returns_all_when_no_session(mock_manager_cls, mock_get_workspace): + mock_get_workspace.return_value = _make_workspace() + files = [ + _make_file(id="f1", name="a.txt", metadata={"origin": "user-upload"}), + _make_file(id="f2", name="b.csv", metadata={"origin": "agent-created"}), + ] + mock_instance = AsyncMock() + mock_instance.list_files.return_value = files + mock_manager_cls.return_value = mock_instance + + response = client.get("/files") + assert response.status_code == 200 + + data = response.json() + assert len(data["files"]) == 2 + assert data["has_more"] is False + assert data["offset"] == 0 + assert data["files"][0]["id"] == "f1" + assert data["files"][0]["metadata"] == {"origin": "user-upload"} + assert data["files"][1]["id"] == "f2" + mock_instance.list_files.assert_called_once_with( + limit=201, offset=0, include_all_sessions=True + ) + + +@patch("backend.api.features.workspace.routes.get_or_create_workspace") +@patch("backend.api.features.workspace.routes.WorkspaceManager") +def test_list_files_scopes_to_session_when_provided( + mock_manager_cls, mock_get_workspace, test_user_id +): + mock_get_workspace.return_value = _make_workspace(user_id=test_user_id) + mock_instance = AsyncMock() + mock_instance.list_files.return_value = [] + mock_manager_cls.return_value = mock_instance + + response = client.get("/files?session_id=sess-123") + assert response.status_code == 200 + + data = response.json() + assert data["files"] == [] + assert data["has_more"] is False + mock_manager_cls.assert_called_once_with(test_user_id, "ws-001", "sess-123") + mock_instance.list_files.assert_called_once_with( + limit=201, offset=0, include_all_sessions=False + ) + + +@patch("backend.api.features.workspace.routes.get_or_create_workspace") +@patch("backend.api.features.workspace.routes.WorkspaceManager") +def test_list_files_null_metadata_coerced_to_empty_dict( + mock_manager_cls, mock_get_workspace +): + """Route uses `f.metadata or {}` for pre-existing files with null metadata.""" + mock_get_workspace.return_value = _make_workspace() + mock_instance = AsyncMock() + mock_instance.list_files.return_value = [_make_file_mock(metadata=None)] + mock_manager_cls.return_value = mock_instance + + response = client.get("/files") + assert response.status_code == 200 + assert response.json()["files"][0]["metadata"] == {} + + +# -- upload_file metadata tests -- + + +@patch("backend.api.features.workspace.routes.get_or_create_workspace") +@patch("backend.api.features.workspace.routes.get_workspace_total_size") +@patch("backend.api.features.workspace.routes.scan_content_safe") +@patch("backend.api.features.workspace.routes.WorkspaceManager") +def test_upload_passes_user_upload_origin_metadata( + mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace +): + mock_get_workspace.return_value = _make_workspace() + mock_total_size.return_value = 100 + written = _make_file(id="new-file", name="doc.pdf") + mock_instance = AsyncMock() + mock_instance.write_file.return_value = written + mock_manager_cls.return_value = mock_instance + + response = client.post( + "/files/upload", + files={"file": ("doc.pdf", b"fake-pdf-content", "application/pdf")}, + ) + assert response.status_code == 200 + + mock_instance.write_file.assert_called_once() + call_kwargs = mock_instance.write_file.call_args + assert call_kwargs.kwargs.get("metadata") == {"origin": "user-upload"} + + +@patch("backend.api.features.workspace.routes.get_or_create_workspace") +@patch("backend.api.features.workspace.routes.get_workspace_total_size") +@patch("backend.api.features.workspace.routes.scan_content_safe") +@patch("backend.api.features.workspace.routes.WorkspaceManager") +def test_upload_returns_409_on_file_conflict( + mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace +): + mock_get_workspace.return_value = _make_workspace() + mock_total_size.return_value = 100 + mock_instance = AsyncMock() + mock_instance.write_file.side_effect = ValueError("File already exists at path") + mock_manager_cls.return_value = mock_instance + + response = client.post( + "/files/upload", + files={"file": ("dup.txt", b"content", "text/plain")}, + ) + assert response.status_code == 409 + assert "already exists" in response.json()["detail"] + + +# -- Restored upload/download/delete security + invariant tests -- + + def _upload( filename: str = "hello.txt", content: bytes = b"Hello, world!", content_type: str = "text/plain", ): - """Helper to POST a file upload.""" return client.post( "/files/upload?session_id=sess-1", files={"file": (filename, io.BytesIO(content), content_type)}, ) -# ---- Happy path ---- +_MOCK_FILE = WorkspaceFile( + id="file-aaa-bbb", + workspace_id="ws-001", + created_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + name="hello.txt", + path="/sessions/sess-1/hello.txt", + mime_type="text/plain", + size_bytes=13, + storage_path="local://hello.txt", +) -def test_upload_happy_path(mocker: pytest_mock.MockFixture): +def test_upload_happy_path(mocker): mocker.patch( "backend.api.features.workspace.routes.get_or_create_workspace", - return_value=MOCK_WORKSPACE, + return_value=_make_workspace(), ) mocker.patch( "backend.api.features.workspace.routes.get_workspace_total_size", @@ -82,7 +238,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture): return_value=None, ) mock_manager = mocker.MagicMock() - mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE) + mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE) mocker.patch( "backend.api.features.workspace.routes.WorkspaceManager", return_value=mock_manager, @@ -96,10 +252,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture): assert data["size_bytes"] == 13 -# ---- Per-file size limit ---- - - -def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture): +def test_upload_exceeds_max_file_size(mocker): """Files larger than max_file_size_mb should be rejected with 413.""" cfg = mocker.patch("backend.api.features.workspace.routes.Config") cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big @@ -109,15 +262,11 @@ def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture): assert response.status_code == 413 -# ---- Storage quota exceeded ---- - - -def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture): +def test_upload_storage_quota_exceeded(mocker): mocker.patch( "backend.api.features.workspace.routes.get_or_create_workspace", - return_value=MOCK_WORKSPACE, + return_value=_make_workspace(), ) - # Current usage already at limit mocker.patch( "backend.api.features.workspace.routes.get_workspace_total_size", return_value=500 * 1024 * 1024, @@ -128,27 +277,22 @@ def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture): assert "Storage limit exceeded" in response.text -# ---- Post-write quota race (B2) ---- - - -def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture): - """If a concurrent upload tips the total over the limit after write, - the file should be soft-deleted and 413 returned.""" +def test_upload_post_write_quota_race(mocker): + """Concurrent upload tipping over limit after write should soft-delete + 413.""" mocker.patch( "backend.api.features.workspace.routes.get_or_create_workspace", - return_value=MOCK_WORKSPACE, + return_value=_make_workspace(), ) - # Pre-write check passes (under limit), but post-write check fails mocker.patch( "backend.api.features.workspace.routes.get_workspace_total_size", - side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit + side_effect=[0, 600 * 1024 * 1024], ) mocker.patch( "backend.api.features.workspace.routes.scan_content_safe", return_value=None, ) mock_manager = mocker.MagicMock() - mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE) + mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE) mocker.patch( "backend.api.features.workspace.routes.WorkspaceManager", return_value=mock_manager, @@ -160,17 +304,14 @@ def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture): response = _upload() assert response.status_code == 413 - mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1") + mock_delete.assert_called_once_with("file-aaa-bbb", "ws-001") -# ---- Any extension accepted (no allowlist) ---- - - -def test_upload_any_extension(mocker: pytest_mock.MockFixture): +def test_upload_any_extension(mocker): """Any file extension should be accepted — ClamAV is the security layer.""" mocker.patch( "backend.api.features.workspace.routes.get_or_create_workspace", - return_value=MOCK_WORKSPACE, + return_value=_make_workspace(), ) mocker.patch( "backend.api.features.workspace.routes.get_workspace_total_size", @@ -181,7 +322,7 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture): return_value=None, ) mock_manager = mocker.MagicMock() - mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE) + mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE) mocker.patch( "backend.api.features.workspace.routes.WorkspaceManager", return_value=mock_manager, @@ -191,16 +332,13 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture): assert response.status_code == 200 -# ---- Virus scan rejection ---- - - -def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture): +def test_upload_blocked_by_virus_scan(mocker): """Files flagged by ClamAV should be rejected and never written to storage.""" from backend.api.features.store.exceptions import VirusDetectedError mocker.patch( "backend.api.features.workspace.routes.get_or_create_workspace", - return_value=MOCK_WORKSPACE, + return_value=_make_workspace(), ) mocker.patch( "backend.api.features.workspace.routes.get_workspace_total_size", @@ -211,7 +349,7 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture): side_effect=VirusDetectedError("Eicar-Test-Signature"), ) mock_manager = mocker.MagicMock() - mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE) + mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE) mocker.patch( "backend.api.features.workspace.routes.WorkspaceManager", return_value=mock_manager, @@ -219,18 +357,14 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture): response = _upload(filename="evil.exe", content=b"X5O!P%@AP...") assert response.status_code == 400 - assert "Virus detected" in response.text mock_manager.write_file.assert_not_called() -# ---- No file extension ---- - - -def test_upload_file_without_extension(mocker: pytest_mock.MockFixture): +def test_upload_file_without_extension(mocker): """Files without an extension should be accepted and stored as-is.""" mocker.patch( "backend.api.features.workspace.routes.get_or_create_workspace", - return_value=MOCK_WORKSPACE, + return_value=_make_workspace(), ) mocker.patch( "backend.api.features.workspace.routes.get_workspace_total_size", @@ -241,7 +375,7 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture): return_value=None, ) mock_manager = mocker.MagicMock() - mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE) + mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE) mocker.patch( "backend.api.features.workspace.routes.WorkspaceManager", return_value=mock_manager, @@ -257,14 +391,11 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture): assert mock_manager.write_file.call_args[0][1] == "Makefile" -# ---- Filename sanitization (SF5) ---- - - -def test_upload_strips_path_components(mocker: pytest_mock.MockFixture): +def test_upload_strips_path_components(mocker): """Path-traversal filenames should be reduced to their basename.""" mocker.patch( "backend.api.features.workspace.routes.get_or_create_workspace", - return_value=MOCK_WORKSPACE, + return_value=_make_workspace(), ) mocker.patch( "backend.api.features.workspace.routes.get_workspace_total_size", @@ -275,28 +406,23 @@ def test_upload_strips_path_components(mocker: pytest_mock.MockFixture): return_value=None, ) mock_manager = mocker.MagicMock() - mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE) + mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE) mocker.patch( "backend.api.features.workspace.routes.WorkspaceManager", return_value=mock_manager, ) - # Filename with traversal _upload(filename="../../etc/passwd.txt") - # write_file should have been called with just the basename mock_manager.write_file.assert_called_once() call_args = mock_manager.write_file.call_args assert call_args[0][1] == "passwd.txt" -# ---- Download ---- - - -def test_download_file_not_found(mocker: pytest_mock.MockFixture): +def test_download_file_not_found(mocker): mocker.patch( "backend.api.features.workspace.routes.get_workspace", - return_value=MOCK_WORKSPACE, + return_value=_make_workspace(), ) mocker.patch( "backend.api.features.workspace.routes.get_workspace_file", @@ -307,14 +433,11 @@ def test_download_file_not_found(mocker: pytest_mock.MockFixture): assert response.status_code == 404 -# ---- Delete ---- - - -def test_delete_file_success(mocker: pytest_mock.MockFixture): +def test_delete_file_success(mocker): """Deleting an existing file should return {"deleted": true}.""" mocker.patch( "backend.api.features.workspace.routes.get_workspace", - return_value=MOCK_WORKSPACE, + return_value=_make_workspace(), ) mock_manager = mocker.MagicMock() mock_manager.delete_file = mocker.AsyncMock(return_value=True) @@ -329,11 +452,11 @@ def test_delete_file_success(mocker: pytest_mock.MockFixture): mock_manager.delete_file.assert_called_once_with("file-aaa-bbb") -def test_delete_file_not_found(mocker: pytest_mock.MockFixture): +def test_delete_file_not_found(mocker): """Deleting a non-existent file should return 404.""" mocker.patch( "backend.api.features.workspace.routes.get_workspace", - return_value=MOCK_WORKSPACE, + return_value=_make_workspace(), ) mock_manager = mocker.MagicMock() mock_manager.delete_file = mocker.AsyncMock(return_value=False) @@ -347,7 +470,7 @@ def test_delete_file_not_found(mocker: pytest_mock.MockFixture): assert "File not found" in response.text -def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture): +def test_delete_file_no_workspace(mocker): """Deleting when user has no workspace should return 404.""" mocker.patch( "backend.api.features.workspace.routes.get_workspace", @@ -357,3 +480,341 @@ def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture): response = client.delete("/files/file-aaa-bbb") assert response.status_code == 404 assert "Workspace not found" in response.text + + +def test_upload_write_file_too_large_returns_413(mocker): + """write_file raises ValueError("File too large: …") → must map to 413.""" + mocker.patch( + "backend.api.features.workspace.routes.get_or_create_workspace", + return_value=_make_workspace(), + ) + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_total_size", + return_value=0, + ) + mocker.patch( + "backend.api.features.workspace.routes.scan_content_safe", + return_value=None, + ) + mock_manager = mocker.MagicMock() + mock_manager.write_file = mocker.AsyncMock( + side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit") + ) + mocker.patch( + "backend.api.features.workspace.routes.WorkspaceManager", + return_value=mock_manager, + ) + + response = _upload() + assert response.status_code == 413 + assert "File too large" in response.text + + +def test_upload_write_file_conflict_returns_409(mocker): + """Non-'File too large' ValueErrors from write_file stay as 409.""" + mocker.patch( + "backend.api.features.workspace.routes.get_or_create_workspace", + return_value=_make_workspace(), + ) + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_total_size", + return_value=0, + ) + mocker.patch( + "backend.api.features.workspace.routes.scan_content_safe", + return_value=None, + ) + mock_manager = mocker.MagicMock() + mock_manager.write_file = mocker.AsyncMock( + side_effect=ValueError("File already exists at path: /sessions/x/a.txt") + ) + mocker.patch( + "backend.api.features.workspace.routes.WorkspaceManager", + return_value=mock_manager, + ) + + response = _upload() + assert response.status_code == 409 + assert "already exists" in response.text + + +@patch("backend.api.features.workspace.routes.get_or_create_workspace") +@patch("backend.api.features.workspace.routes.WorkspaceManager") +def test_list_files_has_more_true_when_limit_exceeded( + mock_manager_cls, mock_get_workspace +): + """The limit+1 fetch trick must flip has_more=True and trim the page.""" + mock_get_workspace.return_value = _make_workspace() + # Backend was asked for limit+1=3, and returned exactly 3 items. + files = [ + _make_file(id="f1", name="a.txt"), + _make_file(id="f2", name="b.txt"), + _make_file(id="f3", name="c.txt"), + ] + mock_instance = AsyncMock() + mock_instance.list_files.return_value = files + mock_manager_cls.return_value = mock_instance + + response = client.get("/files?limit=2") + assert response.status_code == 200 + data = response.json() + assert data["has_more"] is True + assert len(data["files"]) == 2 + assert data["files"][0]["id"] == "f1" + assert data["files"][1]["id"] == "f2" + mock_instance.list_files.assert_called_once_with( + limit=3, offset=0, include_all_sessions=True + ) + + +@patch("backend.api.features.workspace.routes.get_or_create_workspace") +@patch("backend.api.features.workspace.routes.WorkspaceManager") +def test_list_files_has_more_false_when_exactly_page_size( + mock_manager_cls, mock_get_workspace +): + """Exactly `limit` rows means we're on the last page — has_more=False.""" + mock_get_workspace.return_value = _make_workspace() + files = [_make_file(id="f1", name="a.txt"), _make_file(id="f2", name="b.txt")] + mock_instance = AsyncMock() + mock_instance.list_files.return_value = files + mock_manager_cls.return_value = mock_instance + + response = client.get("/files?limit=2") + assert response.status_code == 200 + data = response.json() + assert data["has_more"] is False + assert len(data["files"]) == 2 + + +@patch("backend.api.features.workspace.routes.get_or_create_workspace") +@patch("backend.api.features.workspace.routes.WorkspaceManager") +def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace): + mock_get_workspace.return_value = _make_workspace() + mock_instance = AsyncMock() + mock_instance.list_files.return_value = [] + mock_manager_cls.return_value = mock_instance + + response = client.get("/files?offset=50&limit=10") + assert response.status_code == 200 + assert response.json()["offset"] == 50 + mock_instance.list_files.assert_called_once_with( + limit=11, offset=50, include_all_sessions=True + ) + + +# -- _sanitize_filename_for_header tests -- + + +class TestSanitizeFilenameForHeader: + def test_simple_ascii_attachment(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + assert _sanitize_filename_for_header("report.pdf") == ( + 'attachment; filename="report.pdf"' + ) + + def test_inline_disposition(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + assert _sanitize_filename_for_header("image.png", disposition="inline") == ( + 'inline; filename="image.png"' + ) + + def test_strips_cr_lf_null(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("a\rb\nc\x00d.txt") + assert "\r" not in result + assert "\n" not in result + assert "\x00" not in result + assert 'filename="abcd.txt"' in result + + def test_escapes_quotes(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header('file"name.txt') + assert 'filename="file\\"name.txt"' in result + + def test_header_injection_blocked(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("evil.txt\r\nX-Injected: true") + # CR/LF stripped — the remaining text is safely inside the quoted value + assert "\r" not in result + assert "\n" not in result + assert result == 'attachment; filename="evil.txtX-Injected: true"' + + def test_unicode_uses_rfc5987(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("日本語.pdf") + assert "filename*=UTF-8''" in result + assert "attachment" in result + + def test_unicode_inline(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("图片.png", disposition="inline") + assert result.startswith("inline; filename*=UTF-8''") + + def test_empty_filename(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("") + assert result == 'attachment; filename=""' + + +# -- _create_streaming_response tests -- + + +class TestCreateStreamingResponse: + def test_attachment_disposition_by_default(self): + from backend.api.features.workspace.routes import _create_streaming_response + + file = _make_file(name="data.bin", mime_type="application/octet-stream") + response = _create_streaming_response(b"binary-data", file) + assert ( + response.headers["Content-Disposition"] == 'attachment; filename="data.bin"' + ) + assert response.headers["Content-Type"] == "application/octet-stream" + assert response.headers["Content-Length"] == "11" + assert response.body == b"binary-data" + + def test_inline_disposition(self): + from backend.api.features.workspace.routes import _create_streaming_response + + file = _make_file(name="photo.png", mime_type="image/png") + response = _create_streaming_response(b"\x89PNG", file, inline=True) + assert response.headers["Content-Disposition"] == 'inline; filename="photo.png"' + assert response.headers["Content-Type"] == "image/png" + + def test_inline_sanitizes_filename(self): + from backend.api.features.workspace.routes import _create_streaming_response + + file = _make_file(name='evil"\r\n.txt', mime_type="text/plain") + response = _create_streaming_response(b"data", file, inline=True) + assert "\r" not in response.headers["Content-Disposition"] + assert "\n" not in response.headers["Content-Disposition"] + assert "inline" in response.headers["Content-Disposition"] + + def test_content_length_matches_body(self): + from backend.api.features.workspace.routes import _create_streaming_response + + content = b"x" * 1000 + file = _make_file(name="big.bin", mime_type="application/octet-stream") + response = _create_streaming_response(content, file) + assert response.headers["Content-Length"] == "1000" + + +# -- create_file_download_response tests -- + + +class TestCreateFileDownloadResponse: + @pytest.mark.asyncio + async def test_local_storage_returns_streaming_response(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.retrieve.return_value = b"file contents" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file( + storage_path="local://uploads/test.txt", + mime_type="text/plain", + ) + response = await create_file_download_response(file) + assert response.status_code == 200 + assert response.body == b"file contents" + assert "attachment" in response.headers["Content-Disposition"] + + @pytest.mark.asyncio + async def test_local_storage_inline(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.retrieve.return_value = b"\x89PNG" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file( + storage_path="local://uploads/photo.png", + mime_type="image/png", + name="photo.png", + ) + response = await create_file_download_response(file, inline=True) + assert "inline" in response.headers["Content-Disposition"] + + @pytest.mark.asyncio + async def test_gcs_redirect(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.return_value = ( + "https://storage.googleapis.com/signed-url" + ) + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.pdf") + response = await create_file_download_response(file) + assert response.status_code == 302 + assert ( + response.headers["location"] == "https://storage.googleapis.com/signed-url" + ) + + @pytest.mark.asyncio + async def test_gcs_api_fallback_streams_directly(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.return_value = "/api/fallback" + mock_storage.retrieve.return_value = b"fallback content" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.txt") + response = await create_file_download_response(file) + assert response.status_code == 200 + assert response.body == b"fallback content" + + @pytest.mark.asyncio + async def test_gcs_signed_url_failure_falls_back_to_streaming(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.side_effect = RuntimeError("GCS error") + mock_storage.retrieve.return_value = b"streamed" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.txt") + response = await create_file_download_response(file) + assert response.status_code == 200 + assert response.body == b"streamed" + + @pytest.mark.asyncio + async def test_gcs_total_failure_raises(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.side_effect = RuntimeError("GCS error") + mock_storage.retrieve.side_effect = RuntimeError("Also failed") + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.txt") + with pytest.raises(RuntimeError, match="Also failed"): + await create_file_download_response(file) diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index 6f7af95611..abe261b725 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -17,7 +17,9 @@ from fastapi.routing import APIRoute from prisma.errors import PrismaError import backend.api.features.admin.credit_admin_routes +import backend.api.features.admin.diagnostics_admin_routes import backend.api.features.admin.execution_analytics_routes +import backend.api.features.admin.platform_cost_routes import backend.api.features.admin.rate_limit_admin_routes import backend.api.features.admin.store_admin_routes import backend.api.features.builder @@ -30,6 +32,7 @@ import backend.api.features.library.routes import backend.api.features.mcp.routes as mcp_routes import backend.api.features.oauth import backend.api.features.otto.routes +import backend.api.features.platform_linking.routes import backend.api.features.postmark.postmark import backend.api.features.store.model import backend.api.features.store.routes @@ -319,6 +322,11 @@ app.include_router( tags=["v2", "admin"], prefix="/api/credits", ) +app.include_router( + backend.api.features.admin.diagnostics_admin_routes.router, + tags=["v2", "admin"], + prefix="/api", +) app.include_router( backend.api.features.admin.execution_analytics_routes.router, tags=["v2", "admin"], @@ -329,6 +337,11 @@ app.include_router( tags=["v2", "admin"], prefix="/api/copilot", ) +app.include_router( + backend.api.features.admin.platform_cost_routes.router, + tags=["v2", "admin"], + prefix="/api/admin", +) app.include_router( backend.api.features.executions.review.routes.router, tags=["v2", "executions", "review"], @@ -366,6 +379,11 @@ app.include_router( tags=["oauth"], prefix="/api/oauth", ) +app.include_router( + backend.api.features.platform_linking.routes.router, + tags=["platform-linking"], + prefix="/api/platform-linking", +) app.mount("/external-api", external_api) diff --git a/autogpt_platform/backend/backend/app.py b/autogpt_platform/backend/backend/app.py index 236f098761..534f385009 100644 --- a/autogpt_platform/backend/backend/app.py +++ b/autogpt_platform/backend/backend/app.py @@ -42,11 +42,13 @@ def main(**kwargs): from backend.data.db_manager import DatabaseManager from backend.executor import ExecutionManager, Scheduler from backend.notifications import NotificationManager + from backend.platform_linking.manager import PlatformLinkingManager run_processes( DatabaseManager().set_log_level("warning"), Scheduler(), NotificationManager(), + PlatformLinkingManager(), WebsocketServer(), AgentServer(), ExecutionManager(), diff --git a/autogpt_platform/backend/backend/blocks/_base.py b/autogpt_platform/backend/backend/blocks/_base.py index 56986d15c4..1cc29bd6d4 100644 --- a/autogpt_platform/backend/backend/blocks/_base.py +++ b/autogpt_platform/backend/backend/blocks/_base.py @@ -25,6 +25,7 @@ from backend.data.model import ( Credentials, CredentialsFieldInfo, CredentialsMetaInput, + NodeExecutionStats, SchemaField, is_credentials_field_name, ) @@ -43,7 +44,7 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: from backend.data.execution import ExecutionContext - from backend.data.model import ContributorDetails, NodeExecutionStats + from backend.data.model import ContributorDetails from ..data.graph import Link @@ -167,9 +168,31 @@ class BlockSchema(BaseModel): return cls.cached_jsonschema @classmethod - def validate_data(cls, data: BlockInput) -> str | None: + def validate_data( + cls, + data: BlockInput, + exclude_fields: set[str] | None = None, + ) -> str | None: + schema = cls.jsonschema() + if exclude_fields: + # Drop the excluded fields from both the properties and the + # ``required`` list so jsonschema doesn't flag them as missing. + # Used by the dry-run path to skip credentials validation while + # still validating the remaining block inputs. + schema = { + **schema, + "properties": { + k: v + for k, v in schema.get("properties", {}).items() + if k not in exclude_fields + }, + "required": [ + r for r in schema.get("required", []) if r not in exclude_fields + ], + } + data = {k: v for k, v in data.items() if k not in exclude_fields} return json.validate_with_jsonschema( - schema=cls.jsonschema(), + schema=schema, data={k: v for k, v in data.items() if v is not None}, ) @@ -420,6 +443,19 @@ class BlockWebhookConfig(BlockManualWebhookConfig): class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): _optimized_description: ClassVar[str | None] = None + def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int: + """Return extra runtime cost to charge after this block run completes. + + Called by the executor after a block finishes with COMPLETED status. + The return value is the number of additional base-cost credits to + charge beyond the single credit already collected by charge_usage + at the start of execution. Defaults to 0 (no extra charges). + + Override in blocks (e.g. OrchestratorBlock) that make multiple LLM + calls within one run and should be billed per call. + """ + return 0 + def __init__( self, id: str = "", @@ -455,8 +491,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): disabled: If the block is disabled, it will not be available for execution. static_output: Whether the output links of the block are static by default. """ - from backend.data.model import NodeExecutionStats - self.id = id self.input_schema = input_schema self.output_schema = output_schema @@ -474,7 +508,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): self.is_sensitive_action = is_sensitive_action # Read from ClassVar set by initialize_blocks() self.optimized_description: str | None = type(self)._optimized_description - self.execution_stats: "NodeExecutionStats" = NodeExecutionStats() + self.execution_stats: NodeExecutionStats = NodeExecutionStats() if self.webhook_config: if isinstance(self.webhook_config, BlockWebhookConfig): @@ -554,7 +588,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): return data raise ValueError(f"{self.name} did not produce any output for {output}") - def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats": + def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats: self.execution_stats += stats return self.execution_stats @@ -705,11 +739,16 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): # (e.g. AgentExecutorBlock) get proper input validation. is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False) if is_dry_run: + # Credential fields may be absent (LLM-built agents often skip + # wiring them) or nullified earlier in the pipeline. Validate + # the non-credential inputs against a schema with those fields + # excluded — stripping only the data while keeping them in the + # ``required`` list would falsely report ``'credentials' is a + # required property``. cred_field_names = set(self.input_schema.get_credentials_fields().keys()) - non_cred_data = { - k: v for k, v in input_data.items() if k not in cred_field_names - } - if error := self.input_schema.validate_data(non_cred_data): + if error := self.input_schema.validate_data( + input_data, exclude_fields=cred_field_names + ): raise BlockInputError( message=f"Unable to execute block with invalid input data: {error}", block_name=self.name, diff --git a/autogpt_platform/backend/backend/blocks/ai_condition.py b/autogpt_platform/backend/backend/blocks/ai_condition.py index 6d62d4ab77..db8c023b99 100644 --- a/autogpt_platform/backend/backend/blocks/ai_condition.py +++ b/autogpt_platform/backend/backend/blocks/ai_condition.py @@ -207,6 +207,9 @@ class AIConditionBlock(AIBlockBase): NodeExecutionStats( input_token_count=response.prompt_tokens, output_token_count=response.completion_tokens, + cache_read_token_count=response.cache_read_tokens, + cache_creation_token_count=response.cache_creation_tokens, + provider_cost=response.provider_cost, ) ) self.prompt = response.prompt diff --git a/autogpt_platform/backend/backend/blocks/ai_condition_test.py b/autogpt_platform/backend/backend/blocks/ai_condition_test.py index babb1eb4cf..5520963682 100644 --- a/autogpt_platform/backend/backend/blocks/ai_condition_test.py +++ b/autogpt_platform/backend/backend/blocks/ai_condition_test.py @@ -47,7 +47,13 @@ def _make_input(**overrides) -> AIConditionBlock.Input: return AIConditionBlock.Input(**defaults) -def _mock_llm_response(response_text: str) -> LLMResponse: +def _mock_llm_response( + response_text: str, + *, + cache_read_tokens: int = 0, + cache_creation_tokens: int = 0, + provider_cost: float | None = None, +) -> LLMResponse: return LLMResponse( raw_response="", prompt=[], @@ -56,6 +62,9 @@ def _mock_llm_response(response_text: str) -> LLMResponse: prompt_tokens=10, completion_tokens=5, reasoning=None, + cache_read_tokens=cache_read_tokens, + cache_creation_tokens=cache_creation_tokens, + provider_cost=provider_cost, ) @@ -145,3 +154,35 @@ class TestExceptionPropagation: input_data = _make_input() with pytest.raises(RuntimeError, match="LLM provider error"): await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS) + + +# --------------------------------------------------------------------------- +# Regression: cache tokens and provider_cost must be propagated to stats +# --------------------------------------------------------------------------- + + +class TestCacheTokenPropagation: + @pytest.mark.asyncio + async def test_cache_tokens_propagated_to_stats( + self, monkeypatch: pytest.MonkeyPatch + ): + """cache_read_tokens and cache_creation_tokens must be forwarded to + NodeExecutionStats so that usage dashboards count cached tokens.""" + block = AIConditionBlock() + + async def spy_llm(**kwargs): + return _mock_llm_response( + "true", + cache_read_tokens=7, + cache_creation_tokens=3, + provider_cost=0.0012, + ) + + monkeypatch.setattr(block, "llm_call", spy_llm) + + input_data = _make_input() + await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS) + + assert block.execution_stats.cache_read_token_count == 7 + assert block.execution_stats.cache_creation_token_count == 3 + assert block.execution_stats.provider_cost == 0.0012 diff --git a/autogpt_platform/backend/backend/blocks/apollo/organization.py b/autogpt_platform/backend/backend/blocks/apollo/organization.py index 6722de4a79..66b87ca6b9 100644 --- a/autogpt_platform/backend/backend/blocks/apollo/organization.py +++ b/autogpt_platform/backend/backend/blocks/apollo/organization.py @@ -17,7 +17,7 @@ from backend.blocks.apollo.models import ( PrimaryPhone, SearchOrganizationsRequest, ) -from backend.data.model import CredentialsField, SchemaField +from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField class SearchOrganizationsBlock(Block): @@ -218,6 +218,11 @@ To find IDs, identify the values for organization_id when you call this endpoint ) -> BlockOutput: query = SearchOrganizationsRequest(**input_data.model_dump()) organizations = await self.search_organizations(query, credentials) + self.merge_stats( + NodeExecutionStats( + provider_cost=float(len(organizations)), provider_cost_type="items" + ) + ) for organization in organizations: yield "organization", organization yield "organizations", organizations diff --git a/autogpt_platform/backend/backend/blocks/apollo/people.py b/autogpt_platform/backend/backend/blocks/apollo/people.py index b5059a2a26..5d4f3c22ec 100644 --- a/autogpt_platform/backend/backend/blocks/apollo/people.py +++ b/autogpt_platform/backend/backend/blocks/apollo/people.py @@ -21,7 +21,7 @@ from backend.blocks.apollo.models import ( SearchPeopleRequest, SenorityLevels, ) -from backend.data.model import CredentialsField, SchemaField +from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField class SearchPeopleBlock(Block): @@ -366,4 +366,9 @@ class SearchPeopleBlock(Block): *(enrich_or_fallback(person) for person in people) ) + self.merge_stats( + NodeExecutionStats( + provider_cost=float(len(people)), provider_cost_type="items" + ) + ) yield "people", people diff --git a/autogpt_platform/backend/backend/blocks/autopilot.py b/autogpt_platform/backend/backend/blocks/autopilot.py index d479169c94..ff7c3784ac 100644 --- a/autogpt_platform/backend/backend/blocks/autopilot.py +++ b/autogpt_platform/backend/backend/blocks/autopilot.py @@ -4,6 +4,7 @@ import asyncio import contextvars import json import logging +import uuid from typing import TYPE_CHECKING, Any from typing_extensions import TypedDict # Needed for Python <3.12 compatibility @@ -22,6 +23,7 @@ from backend.copilot.permissions import ( validate_block_identifiers, ) from backend.data.model import SchemaField +from backend.util.exceptions import BlockExecutionError if TYPE_CHECKING: from backend.data.execution import ExecutionContext @@ -31,6 +33,37 @@ logger = logging.getLogger(__name__) # Block ID shared between autopilot.py and copilot prompting.py. AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6" +# Identifiers used when registering an AutoPilotBlock turn with the +# stream registry — distinguishes block-originated turns from sub-session +# or HTTP SSE turns in logs / observability. +_AUTOPILOT_TOOL_CALL_ID = "autopilot_block" +_AUTOPILOT_TOOL_NAME = "autopilot_block" + +# Ceiling on how long AutoPilotBlock.execute_copilot will wait for the +# enqueued turn's terminal event. Graph blocks run synchronously from +# the caller's perspective so we wait effectively as long as needed; 6h +# matches the previous abandoned-task cap and is much longer than any +# legitimate AutoPilot turn. +_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS = 6 * 60 * 60 # 6 hours + + +class SubAgentRecursionError(BlockExecutionError): + """Raised when the AutoPilot sub-agent nesting depth limit is exceeded. + + Inherits :class:`BlockExecutionError` — this is a known, handled + runtime failure at the block level (caller nested AutoPilotBlocks + beyond the configured limit). Surfaces with the block_name / + block_id the block framework expects, instead of being wrapped in + ``BlockUnknownError``. + """ + + def __init__(self, message: str) -> None: + super().__init__( + message=message, + block_name="AutoPilotBlock", + block_id=AUTOPILOT_BLOCK_ID, + ) + class ToolCallEntry(TypedDict): """A single tool invocation record from an autopilot execution.""" @@ -263,11 +296,15 @@ class AutoPilotBlock(Block): user_id: str, permissions: "CopilotPermissions | None" = None, ) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]: - """Invoke the copilot and collect all stream results. + """Invoke the copilot on the copilot_executor queue and aggregate the + result. - Delegates to :func:`collect_copilot_response` — the shared helper that - consumes ``stream_chat_completion_sdk`` without wrapping it in an - ``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts). + Delegates to :func:`run_copilot_turn_via_queue` — the shared + primitive used by ``run_sub_session`` too — which creates the + stream_registry meta record, enqueues the job, and waits on the + Redis stream for the terminal event. Any available + copilot_executor worker picks up the job, so this call survives + the graph-executor worker dying mid-turn (RabbitMQ redelivers). Args: prompt: The user task/instruction. @@ -280,8 +317,8 @@ class AutoPilotBlock(Block): Returns: A tuple of (response_text, tool_calls, history_json, session_id, usage). """ - from backend.copilot.sdk.collect import ( - collect_copilot_response, # avoid circular import + from backend.copilot.sdk.session_waiter import ( + run_copilot_turn_via_queue, # avoid circular import ) tokens = _check_recursion(max_recursion_depth) @@ -294,14 +331,35 @@ class AutoPilotBlock(Block): if system_context: effective_prompt = f"[System Context: {system_context}]\n\n{prompt}" - result = await collect_copilot_response( + outcome, result = await run_copilot_turn_via_queue( session_id=session_id, - message=effective_prompt, user_id=user_id, + message=effective_prompt, + # Graph block execution is synchronous from the caller's + # perspective — wait effectively as long as needed. The + # SDK enforces its own idle-based timeout inside the + # stream_registry pipeline. + timeout=_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS, permissions=effective_permissions, + tool_call_id=_AUTOPILOT_TOOL_CALL_ID, + tool_name=_AUTOPILOT_TOOL_NAME, ) + if outcome == "failed": + raise RuntimeError( + "AutoPilot turn failed — see the session's transcript" + ) + if outcome == "running": + raise RuntimeError( + "AutoPilot turn did not complete within " + f"{_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS}s — session " + f"{session_id}" + ) - # Build a lightweight conversation summary from streamed data. + # Build a lightweight conversation summary from the aggregated data. + # When ``result.queued`` is True the prompt rode on an already- + # in-flight turn (``run_copilot_turn_via_queue`` queued it and + # waited on the existing turn's stream); the aggregated result + # is still valid, so the same rendering path applies. turn_messages: list[dict[str, Any]] = [ {"role": "user", "content": effective_prompt}, ] @@ -310,7 +368,7 @@ class AutoPilotBlock(Block): { "role": "assistant", "content": result.response_text, - "tool_calls": result.tool_calls, + "tool_calls": [tc.model_dump() for tc in result.tool_calls], } ) else: @@ -321,11 +379,11 @@ class AutoPilotBlock(Block): tool_calls: list[ToolCallEntry] = [ { - "tool_call_id": tc["tool_call_id"], - "tool_name": tc["tool_name"], - "input": tc["input"], - "output": tc["output"], - "success": tc["success"], + "tool_call_id": tc.tool_call_id, + "tool_name": tc.tool_name, + "input": tc.input, + "output": tc.output, + "success": tc.success, } for tc in result.tool_calls ] @@ -383,7 +441,8 @@ class AutoPilotBlock(Block): sid = input_data.session_id if not sid: sid = await self.create_session( - execution_context.user_id, dry_run=input_data.dry_run + execution_context.user_id, + dry_run=input_data.dry_run or execution_context.dry_run, ) # NOTE: No asyncio.timeout() here — the SDK manages its own @@ -409,8 +468,41 @@ class AutoPilotBlock(Block): yield "session_id", sid yield "error", "AutoPilot execution was cancelled." raise + except SubAgentRecursionError as exc: + # Deliberate block — re-enqueueing would immediately hit the limit + # again, so skip recovery and just surface the error. + yield "session_id", sid + yield "error", str(exc) except Exception as exc: yield "session_id", sid + # Recovery enqueue must happen BEFORE yielding "error": the block + # framework (_base.execute) raises BlockExecutionError immediately + # when it sees ("error", ...) and stops consuming the generator, + # so any code after that yield is dead code in production. + effective_prompt = input_data.prompt + if input_data.system_context: + effective_prompt = ( + f"[System Context: {input_data.system_context}]\n\n" + f"{input_data.prompt}" + ) + try: + await _enqueue_for_recovery( + sid, + execution_context.user_id, + effective_prompt, + input_data.dry_run or execution_context.dry_run, + ) + except asyncio.CancelledError: + # Task cancelled during recovery — still yield the error + # so the session_id + error pair is visible before re-raising. + yield "error", str(exc) + raise + except Exception: + logger.warning( + "AutoPilot session %s: recovery enqueue raised unexpectedly", + sid[:12], + exc_info=True, + ) yield "error", str(exc) @@ -438,13 +530,13 @@ def _check_recursion( when the caller exits to restore the previous depth. Raises: - RuntimeError: If the current depth already meets or exceeds the limit. + SubAgentRecursionError: If the current depth already meets or exceeds the limit. """ current = _autopilot_recursion_depth.get() inherited = _autopilot_recursion_limit.get() limit = max_depth if inherited is None else min(inherited, max_depth) if current >= limit: - raise RuntimeError( + raise SubAgentRecursionError( f"AutoPilot recursion depth limit reached ({limit}). " "The autopilot has called itself too many times." ) @@ -535,3 +627,51 @@ def _merge_inherited_permissions( # Return the token so the caller can restore the previous value in finally. token = _inherited_permissions.set(merged) return merged, token + + +# --------------------------------------------------------------------------- +# Recovery helpers +# --------------------------------------------------------------------------- + + +async def _enqueue_for_recovery( + session_id: str, + user_id: str, + message: str, + dry_run: bool, +) -> None: + """Re-enqueue an orphaned sub-agent session so a fresh executor picks it up. + + When ``execute_copilot`` raises an unexpected exception the sub-agent + session is left with ``last_role=user`` and no active consumer — identical + to the state that caused Toran's reports of silent sub-agents. Publishing + the original prompt back to the copilot queue lets the executor service + resume the session without manual intervention. + + Skipped for dry-run sessions (no real consumers listen to the queue for + simulated sessions). Any failure to publish is logged and swallowed so + it never masks the original exception. + """ + if dry_run: + return + try: + from backend.copilot.executor.utils import ( # avoid circular import + enqueue_copilot_turn, + ) + + await asyncio.wait_for( + enqueue_copilot_turn( + session_id=session_id, + user_id=user_id, + message=message, + turn_id=str(uuid.uuid4()), + ), + timeout=10, + ) + logger.info("AutoPilot session %s enqueued for recovery", session_id[:12]) + except Exception: + logger.warning( + "AutoPilot session %s: failed to enqueue for recovery", + session_id[:12], + exc_info=True, + ) diff --git a/autogpt_platform/backend/backend/blocks/block_cost_tracking_test.py b/autogpt_platform/backend/backend/blocks/block_cost_tracking_test.py new file mode 100644 index 0000000000..45db11b717 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/block_cost_tracking_test.py @@ -0,0 +1,712 @@ +"""Unit tests for merge_stats cost tracking in individual blocks. + +Covers the exa code_context, exa contents, and apollo organization blocks +to verify provider cost is correctly extracted and reported. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import SecretStr + +from backend.data.model import APIKeyCredentials, NodeExecutionStats + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +TEST_EXA_CREDENTIALS = APIKeyCredentials( + id="01234567-89ab-cdef-0123-456789abcdef", + provider="exa", + api_key=SecretStr("mock-exa-api-key"), + title="Mock Exa API key", + expires_at=None, +) + +TEST_EXA_CREDENTIALS_INPUT = { + "provider": TEST_EXA_CREDENTIALS.provider, + "id": TEST_EXA_CREDENTIALS.id, + "type": TEST_EXA_CREDENTIALS.type, + "title": TEST_EXA_CREDENTIALS.title, +} + + +# --------------------------------------------------------------------------- +# ExaCodeContextBlock — cost_dollars is a string like "0.005" +# --------------------------------------------------------------------------- + + +class TestExaCodeContextBlockCostTracking: + @pytest.mark.asyncio + async def test_merge_stats_called_with_float_cost(self): + """float(cost_dollars) parsed from API string and passed to merge_stats.""" + from backend.blocks.exa.code_context import ExaCodeContextBlock + + block = ExaCodeContextBlock() + + api_response = { + "requestId": "req-1", + "query": "how to use hooks", + "response": "Here are some examples...", + "resultsCount": 3, + "costDollars": "0.005", + "searchTime": 1.2, + "outputTokens": 100, + } + + mock_resp = MagicMock() + mock_resp.json.return_value = api_response + + accumulated: list[NodeExecutionStats] = [] + + with ( + patch( + "backend.blocks.exa.code_context.Requests.post", + new_callable=AsyncMock, + return_value=mock_resp, + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = ExaCodeContextBlock.Input( + query="how to use hooks", + credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type] + ) + results = [] + async for output in block.run( + input_data, + credentials=TEST_EXA_CREDENTIALS, + ): + results.append(output) + + assert len(accumulated) == 1 + assert accumulated[0].provider_cost == pytest.approx(0.005) + + @pytest.mark.asyncio + async def test_invalid_cost_dollars_does_not_raise(self): + """When cost_dollars cannot be parsed as float, merge_stats is not called.""" + from backend.blocks.exa.code_context import ExaCodeContextBlock + + block = ExaCodeContextBlock() + + api_response = { + "requestId": "req-2", + "query": "query", + "response": "response", + "resultsCount": 0, + "costDollars": "N/A", + "searchTime": 0.5, + "outputTokens": 0, + } + + mock_resp = MagicMock() + mock_resp.json.return_value = api_response + + merge_calls: list[NodeExecutionStats] = [] + + with ( + patch( + "backend.blocks.exa.code_context.Requests.post", + new_callable=AsyncMock, + return_value=mock_resp, + ), + patch.object( + block, "merge_stats", side_effect=lambda s: merge_calls.append(s) + ), + ): + input_data = ExaCodeContextBlock.Input( + query="query", + credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run( + input_data, + credentials=TEST_EXA_CREDENTIALS, + ): + pass + + assert merge_calls == [] + + @pytest.mark.asyncio + async def test_zero_cost_is_tracked(self): + """A zero cost_dollars string '0.0' should still be recorded.""" + from backend.blocks.exa.code_context import ExaCodeContextBlock + + block = ExaCodeContextBlock() + + api_response = { + "requestId": "req-3", + "query": "query", + "response": "...", + "resultsCount": 1, + "costDollars": "0.0", + "searchTime": 0.1, + "outputTokens": 10, + } + + mock_resp = MagicMock() + mock_resp.json.return_value = api_response + + accumulated: list[NodeExecutionStats] = [] + + with ( + patch( + "backend.blocks.exa.code_context.Requests.post", + new_callable=AsyncMock, + return_value=mock_resp, + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = ExaCodeContextBlock.Input( + query="query", + credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run( + input_data, + credentials=TEST_EXA_CREDENTIALS, + ): + pass + + assert len(accumulated) == 1 + assert accumulated[0].provider_cost == 0.0 + + +# --------------------------------------------------------------------------- +# ExaContentsBlock — response.cost_dollars.total (CostDollars model) +# --------------------------------------------------------------------------- + + +class TestExaContentsBlockCostTracking: + @pytest.mark.asyncio + async def test_merge_stats_called_with_cost_dollars_total(self): + """provider_cost equals response.cost_dollars.total when present.""" + from backend.blocks.exa.contents import ExaContentsBlock + from backend.blocks.exa.helpers import CostDollars + + block = ExaContentsBlock() + + cost_dollars = CostDollars(total=0.012) + + mock_response = MagicMock() + mock_response.results = [] + mock_response.context = None + mock_response.statuses = None + mock_response.cost_dollars = cost_dollars + + accumulated: list[NodeExecutionStats] = [] + + with ( + patch( + "backend.blocks.exa.contents.AsyncExa", + return_value=MagicMock( + get_contents=AsyncMock(return_value=mock_response) + ), + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = ExaContentsBlock.Input( + urls=["https://example.com"], + credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run( + input_data, + credentials=TEST_EXA_CREDENTIALS, + ): + pass + + assert len(accumulated) == 1 + assert accumulated[0].provider_cost == pytest.approx(0.012) + + @pytest.mark.asyncio + async def test_no_merge_stats_when_cost_dollars_absent(self): + """When response.cost_dollars is None, merge_stats is not called.""" + from backend.blocks.exa.contents import ExaContentsBlock + + block = ExaContentsBlock() + + mock_response = MagicMock() + mock_response.results = [] + mock_response.context = None + mock_response.statuses = None + mock_response.cost_dollars = None + + accumulated: list[NodeExecutionStats] = [] + + with ( + patch( + "backend.blocks.exa.contents.AsyncExa", + return_value=MagicMock( + get_contents=AsyncMock(return_value=mock_response) + ), + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = ExaContentsBlock.Input( + urls=["https://example.com"], + credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run( + input_data, + credentials=TEST_EXA_CREDENTIALS, + ): + pass + + assert accumulated == [] + + +# --------------------------------------------------------------------------- +# SearchOrganizationsBlock — provider_cost = float(len(organizations)) +# --------------------------------------------------------------------------- + + +class TestSearchOrganizationsBlockCostTracking: + @pytest.mark.asyncio + async def test_merge_stats_called_with_org_count(self): + """provider_cost == number of returned organizations, type == 'items'.""" + from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS + from backend.blocks.apollo._auth import ( + TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT, + ) + from backend.blocks.apollo.models import Organization + from backend.blocks.apollo.organization import SearchOrganizationsBlock + + block = SearchOrganizationsBlock() + + fake_orgs = [Organization(id=str(i), name=f"Org{i}") for i in range(3)] + + accumulated: list[NodeExecutionStats] = [] + + with ( + patch.object( + SearchOrganizationsBlock, + "search_organizations", + new_callable=AsyncMock, + return_value=fake_orgs, + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = SearchOrganizationsBlock.Input( + credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type] + ) + results = [] + async for output in block.run( + input_data, + credentials=APOLLO_CREDS, + ): + results.append(output) + + assert len(accumulated) == 1 + assert accumulated[0].provider_cost == pytest.approx(3.0) + assert accumulated[0].provider_cost_type == "items" + + @pytest.mark.asyncio + async def test_empty_org_list_tracks_zero(self): + """An empty organization list results in provider_cost=0.0.""" + from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS + from backend.blocks.apollo._auth import ( + TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT, + ) + from backend.blocks.apollo.organization import SearchOrganizationsBlock + + block = SearchOrganizationsBlock() + accumulated: list[NodeExecutionStats] = [] + + with ( + patch.object( + SearchOrganizationsBlock, + "search_organizations", + new_callable=AsyncMock, + return_value=[], + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = SearchOrganizationsBlock.Input( + credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run( + input_data, + credentials=APOLLO_CREDS, + ): + pass + + assert len(accumulated) == 1 + assert accumulated[0].provider_cost == 0.0 + assert accumulated[0].provider_cost_type == "items" + + +# --------------------------------------------------------------------------- +# JinaEmbeddingBlock — token count from usage.total_tokens +# --------------------------------------------------------------------------- + + +class TestJinaEmbeddingBlockCostTracking: + @pytest.mark.asyncio + async def test_merge_stats_called_with_token_count(self): + """provider token count is recorded when API returns usage.total_tokens.""" + from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS + from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT + from backend.blocks.jina.embeddings import JinaEmbeddingBlock + + block = JinaEmbeddingBlock() + + api_response = { + "data": [{"embedding": [0.1, 0.2, 0.3]}], + "usage": {"total_tokens": 42}, + } + mock_resp = MagicMock() + mock_resp.json.return_value = api_response + + accumulated: list[NodeExecutionStats] = [] + + with ( + patch( + "backend.blocks.jina.embeddings.Requests.post", + new_callable=AsyncMock, + return_value=mock_resp, + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = JinaEmbeddingBlock.Input( + texts=["hello world"], + credentials=JINA_CREDS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run(input_data, credentials=JINA_CREDS): + pass + + assert len(accumulated) == 1 + assert accumulated[0].input_token_count == 42 + + @pytest.mark.asyncio + async def test_no_merge_stats_when_usage_absent(self): + """When API response omits usage field, merge_stats is not called.""" + from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS + from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT + from backend.blocks.jina.embeddings import JinaEmbeddingBlock + + block = JinaEmbeddingBlock() + + api_response = { + "data": [{"embedding": [0.1, 0.2, 0.3]}], + } + mock_resp = MagicMock() + mock_resp.json.return_value = api_response + + accumulated: list[NodeExecutionStats] = [] + + with ( + patch( + "backend.blocks.jina.embeddings.Requests.post", + new_callable=AsyncMock, + return_value=mock_resp, + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = JinaEmbeddingBlock.Input( + texts=["hello"], + credentials=JINA_CREDS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run(input_data, credentials=JINA_CREDS): + pass + + assert accumulated == [] + + +# --------------------------------------------------------------------------- +# UnrealTextToSpeechBlock — character count from input text length +# --------------------------------------------------------------------------- + + +class TestUnrealTextToSpeechBlockCostTracking: + @pytest.mark.asyncio + async def test_merge_stats_called_with_character_count(self): + """provider_cost equals len(text) with type='characters'.""" + from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS + from backend.blocks.text_to_speech_block import ( + TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT, + ) + from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock + + block = UnrealTextToSpeechBlock() + test_text = "Hello, world!" + + with ( + patch.object( + UnrealTextToSpeechBlock, + "call_unreal_speech_api", + new_callable=AsyncMock, + return_value={"OutputUri": "https://example.com/audio.mp3"}, + ), + patch.object(block, "merge_stats") as mock_merge, + ): + input_data = UnrealTextToSpeechBlock.Input( + text=test_text, + credentials=TTS_CREDS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run(input_data, credentials=TTS_CREDS): + pass + + mock_merge.assert_called_once() + stats = mock_merge.call_args[0][0] + assert stats.provider_cost == float(len(test_text)) + assert stats.provider_cost_type == "characters" + + @pytest.mark.asyncio + async def test_empty_text_gives_zero_characters(self): + """An empty text string results in provider_cost=0.0.""" + from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS + from backend.blocks.text_to_speech_block import ( + TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT, + ) + from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock + + block = UnrealTextToSpeechBlock() + + with ( + patch.object( + UnrealTextToSpeechBlock, + "call_unreal_speech_api", + new_callable=AsyncMock, + return_value={"OutputUri": "https://example.com/audio.mp3"}, + ), + patch.object(block, "merge_stats") as mock_merge, + ): + input_data = UnrealTextToSpeechBlock.Input( + text="", + credentials=TTS_CREDS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run(input_data, credentials=TTS_CREDS): + pass + + mock_merge.assert_called_once() + stats = mock_merge.call_args[0][0] + assert stats.provider_cost == 0.0 + assert stats.provider_cost_type == "characters" + + +# --------------------------------------------------------------------------- +# GoogleMapsSearchBlock — item count from search_places results +# --------------------------------------------------------------------------- + + +class TestGoogleMapsSearchBlockCostTracking: + @pytest.mark.asyncio + async def test_merge_stats_called_with_place_count(self): + """provider_cost equals number of returned places, type == 'items'.""" + from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS + from backend.blocks.google_maps import ( + TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT, + ) + from backend.blocks.google_maps import GoogleMapsSearchBlock + + block = GoogleMapsSearchBlock() + + fake_places = [{"name": f"Place{i}", "address": f"Addr{i}"} for i in range(4)] + accumulated: list[NodeExecutionStats] = [] + + with ( + patch.object( + GoogleMapsSearchBlock, + "search_places", + return_value=fake_places, + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = GoogleMapsSearchBlock.Input( + query="coffee shops", + credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run(input_data, credentials=MAPS_CREDS): + pass + + assert len(accumulated) == 1 + assert accumulated[0].provider_cost == 4.0 + assert accumulated[0].provider_cost_type == "items" + + @pytest.mark.asyncio + async def test_empty_results_tracks_zero(self): + """Zero places returned results in provider_cost=0.0.""" + from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS + from backend.blocks.google_maps import ( + TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT, + ) + from backend.blocks.google_maps import GoogleMapsSearchBlock + + block = GoogleMapsSearchBlock() + accumulated: list[NodeExecutionStats] = [] + + with ( + patch.object( + GoogleMapsSearchBlock, + "search_places", + return_value=[], + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = GoogleMapsSearchBlock.Input( + query="nothing here", + credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run(input_data, credentials=MAPS_CREDS): + pass + + assert len(accumulated) == 1 + assert accumulated[0].provider_cost == 0.0 + assert accumulated[0].provider_cost_type == "items" + + +# --------------------------------------------------------------------------- +# SmartLeadAddLeadsBlock — item count from lead_list length +# --------------------------------------------------------------------------- + + +class TestSmartLeadAddLeadsBlockCostTracking: + @pytest.mark.asyncio + async def test_merge_stats_called_with_lead_count(self): + """provider_cost equals number of leads uploaded, type == 'items'.""" + from backend.blocks.smartlead._auth import TEST_CREDENTIALS as SL_CREDS + from backend.blocks.smartlead._auth import ( + TEST_CREDENTIALS_INPUT as SL_CREDS_INPUT, + ) + from backend.blocks.smartlead.campaign import AddLeadToCampaignBlock + from backend.blocks.smartlead.models import ( + AddLeadsToCampaignResponse, + LeadInput, + ) + + block = AddLeadToCampaignBlock() + + fake_leads = [ + LeadInput(first_name="Alice", last_name="A", email="alice@example.com"), + LeadInput(first_name="Bob", last_name="B", email="bob@example.com"), + ] + fake_response = AddLeadsToCampaignResponse( + ok=True, + upload_count=2, + total_leads=2, + block_count=0, + duplicate_count=0, + invalid_email_count=0, + invalid_emails=[], + already_added_to_campaign=0, + unsubscribed_leads=[], + is_lead_limit_exhausted=False, + lead_import_stopped_count=0, + bounce_count=0, + ) + accumulated: list[NodeExecutionStats] = [] + + with ( + patch.object( + AddLeadToCampaignBlock, + "add_leads_to_campaign", + new_callable=AsyncMock, + return_value=fake_response, + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = AddLeadToCampaignBlock.Input( + campaign_id=123, + lead_list=fake_leads, + credentials=SL_CREDS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run(input_data, credentials=SL_CREDS): + pass + + assert len(accumulated) == 1 + assert accumulated[0].provider_cost == 2.0 + assert accumulated[0].provider_cost_type == "items" + + +# --------------------------------------------------------------------------- +# SearchPeopleBlock — item count from people list length +# --------------------------------------------------------------------------- + + +class TestSearchPeopleBlockCostTracking: + @pytest.mark.asyncio + async def test_merge_stats_called_with_people_count(self): + """provider_cost equals number of returned people, type == 'items'.""" + from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS + from backend.blocks.apollo._auth import ( + TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT, + ) + from backend.blocks.apollo.models import Contact + from backend.blocks.apollo.people import SearchPeopleBlock + + block = SearchPeopleBlock() + fake_people = [Contact(id=str(i), first_name=f"Person{i}") for i in range(5)] + accumulated: list[NodeExecutionStats] = [] + + with ( + patch.object( + SearchPeopleBlock, + "search_people", + new_callable=AsyncMock, + return_value=fake_people, + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = SearchPeopleBlock.Input( + credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run(input_data, credentials=APOLLO_CREDS): + pass + + assert len(accumulated) == 1 + assert accumulated[0].provider_cost == pytest.approx(5.0) + assert accumulated[0].provider_cost_type == "items" + + @pytest.mark.asyncio + async def test_empty_people_list_tracks_zero(self): + """An empty people list results in provider_cost=0.0.""" + from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS + from backend.blocks.apollo._auth import ( + TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT, + ) + from backend.blocks.apollo.people import SearchPeopleBlock + + block = SearchPeopleBlock() + accumulated: list[NodeExecutionStats] = [] + + with ( + patch.object( + SearchPeopleBlock, + "search_people", + new_callable=AsyncMock, + return_value=[], + ), + patch.object( + block, "merge_stats", side_effect=lambda s: accumulated.append(s) + ), + ): + input_data = SearchPeopleBlock.Input( + credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type] + ) + async for _ in block.run(input_data, credentials=APOLLO_CREDS): + pass + + assert len(accumulated) == 1 + assert accumulated[0].provider_cost == 0.0 + assert accumulated[0].provider_cost_type == "items" diff --git a/autogpt_platform/backend/backend/blocks/exa/code_context.py b/autogpt_platform/backend/backend/blocks/exa/code_context.py index 962d13fdfa..2855c1dc4a 100644 --- a/autogpt_platform/backend/backend/blocks/exa/code_context.py +++ b/autogpt_platform/backend/backend/blocks/exa/code_context.py @@ -9,6 +9,7 @@ from typing import Union from pydantic import BaseModel +from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -116,3 +117,10 @@ class ExaCodeContextBlock(Block): yield "cost_dollars", context.cost_dollars yield "search_time", context.search_time yield "output_tokens", context.output_tokens + + # Parse cost_dollars (API returns as string, e.g. "0.005") + try: + cost_usd = float(context.cost_dollars) + self.merge_stats(NodeExecutionStats(provider_cost=cost_usd)) + except (ValueError, TypeError): + pass diff --git a/autogpt_platform/backend/backend/blocks/exa/contents.py b/autogpt_platform/backend/backend/blocks/exa/contents.py index 9ab854fa85..8b2deaf036 100644 --- a/autogpt_platform/backend/backend/blocks/exa/contents.py +++ b/autogpt_platform/backend/backend/blocks/exa/contents.py @@ -4,6 +4,7 @@ from typing import Optional from exa_py import AsyncExa from pydantic import BaseModel +from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -223,3 +224,6 @@ class ExaContentsBlock(Block): if response.cost_dollars: yield "cost_dollars", response.cost_dollars + self.merge_stats( + NodeExecutionStats(provider_cost=response.cost_dollars.total) + ) diff --git a/autogpt_platform/backend/backend/blocks/exa/cost_tracking_test.py b/autogpt_platform/backend/backend/blocks/exa/cost_tracking_test.py new file mode 100644 index 0000000000..1ee395e539 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/exa/cost_tracking_test.py @@ -0,0 +1,575 @@ +"""Tests for cost tracking in Exa blocks. + +Covers the cost_dollars → provider_cost → merge_stats path for both +ExaContentsBlock and ExaCodeContextBlock. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT +from backend.data.model import NodeExecutionStats + + +class TestExaCodeContextCostTracking: + """ExaCodeContextBlock parses cost_dollars (string) and calls merge_stats.""" + + @pytest.mark.asyncio + async def test_valid_cost_string_is_parsed_and_merged(self): + """A numeric cost string like '0.005' is merged as provider_cost.""" + from backend.blocks.exa.code_context import ExaCodeContextBlock + + block = ExaCodeContextBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + api_response = { + "requestId": "req-1", + "query": "test query", + "response": "some code", + "resultsCount": 3, + "costDollars": "0.005", + "searchTime": 1.2, + "outputTokens": 100, + } + + with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls: + mock_resp = MagicMock() + mock_resp.json.return_value = api_response + mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp) + + outputs = [] + async for key, value in block.run( + block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + credentials=TEST_CREDENTIALS, + ): + outputs.append((key, value)) + + assert any(k == "cost_dollars" for k, _ in outputs) + assert len(merged) == 1 + assert merged[0].provider_cost == pytest.approx(0.005) + + @pytest.mark.asyncio + async def test_invalid_cost_string_does_not_raise(self): + """A non-numeric cost_dollars value is swallowed silently.""" + from backend.blocks.exa.code_context import ExaCodeContextBlock + + block = ExaCodeContextBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + api_response = { + "requestId": "req-2", + "query": "test", + "response": "code", + "resultsCount": 0, + "costDollars": "N/A", + "searchTime": 0.5, + "outputTokens": 0, + } + + with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls: + mock_resp = MagicMock() + mock_resp.json.return_value = api_response + mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp) + + outputs = [] + async for key, value in block.run( + block.Input(query="test", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + credentials=TEST_CREDENTIALS, + ): + outputs.append((key, value)) + + # No merge_stats call because float() raised ValueError + assert len(merged) == 0 + + @pytest.mark.asyncio + async def test_zero_cost_string_is_merged(self): + """'0.0' is a valid cost — should still be tracked.""" + from backend.blocks.exa.code_context import ExaCodeContextBlock + + block = ExaCodeContextBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + api_response = { + "requestId": "req-3", + "query": "free query", + "response": "result", + "resultsCount": 1, + "costDollars": "0.0", + "searchTime": 0.1, + "outputTokens": 10, + } + + with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls: + mock_resp = MagicMock() + mock_resp.json.return_value = api_response + mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp) + + async for _ in block.run( + block.Input(query="free query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + credentials=TEST_CREDENTIALS, + ): + pass + + assert len(merged) == 1 + assert merged[0].provider_cost == pytest.approx(0.0) + + +class TestExaContentsCostTracking: + """ExaContentsBlock merges cost_dollars.total as provider_cost.""" + + @pytest.mark.asyncio + async def test_cost_dollars_total_is_merged(self): + """When the SDK response includes cost_dollars, its total is merged.""" + from backend.blocks.exa.contents import ExaContentsBlock + from backend.blocks.exa.helpers import CostDollars + + block = ExaContentsBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + mock_sdk_response = MagicMock() + mock_sdk_response.results = [] + mock_sdk_response.context = None + mock_sdk_response.statuses = None + mock_sdk_response.cost_dollars = CostDollars(total=0.012) + + with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls: + mock_exa = MagicMock() + mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response) + mock_exa_cls.return_value = mock_exa + + async for _ in block.run( + block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + credentials=TEST_CREDENTIALS, + ): + pass + + assert len(merged) == 1 + assert merged[0].provider_cost == pytest.approx(0.012) + + @pytest.mark.asyncio + async def test_no_cost_dollars_skips_merge(self): + """When cost_dollars is absent, merge_stats is not called.""" + from backend.blocks.exa.contents import ExaContentsBlock + + block = ExaContentsBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + mock_sdk_response = MagicMock() + mock_sdk_response.results = [] + mock_sdk_response.context = None + mock_sdk_response.statuses = None + mock_sdk_response.cost_dollars = None + + with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls: + mock_exa = MagicMock() + mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response) + mock_exa_cls.return_value = mock_exa + + async for _ in block.run( + block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + credentials=TEST_CREDENTIALS, + ): + pass + + assert len(merged) == 0 + + @pytest.mark.asyncio + async def test_zero_cost_dollars_is_merged(self): + """A total of 0.0 (free tier) should still be merged.""" + from backend.blocks.exa.contents import ExaContentsBlock + from backend.blocks.exa.helpers import CostDollars + + block = ExaContentsBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + mock_sdk_response = MagicMock() + mock_sdk_response.results = [] + mock_sdk_response.context = None + mock_sdk_response.statuses = None + mock_sdk_response.cost_dollars = CostDollars(total=0.0) + + with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls: + mock_exa = MagicMock() + mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response) + mock_exa_cls.return_value = mock_exa + + async for _ in block.run( + block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + credentials=TEST_CREDENTIALS, + ): + pass + + assert len(merged) == 1 + assert merged[0].provider_cost == pytest.approx(0.0) + + +class TestExaSearchCostTracking: + """ExaSearchBlock merges cost_dollars.total as provider_cost.""" + + @pytest.mark.asyncio + async def test_cost_dollars_total_is_merged(self): + """When the SDK response includes cost_dollars, its total is merged.""" + from backend.blocks.exa.helpers import CostDollars + from backend.blocks.exa.search import ExaSearchBlock + + block = ExaSearchBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + mock_sdk_response = MagicMock() + mock_sdk_response.results = [] + mock_sdk_response.context = None + mock_sdk_response.resolved_search_type = None + mock_sdk_response.cost_dollars = CostDollars(total=0.008) + + with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls: + mock_exa = MagicMock() + mock_exa.search = AsyncMock(return_value=mock_sdk_response) + mock_exa_cls.return_value = mock_exa + + async for _ in block.run( + block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + credentials=TEST_CREDENTIALS, + ): + pass + + assert len(merged) == 1 + assert merged[0].provider_cost == pytest.approx(0.008) + + @pytest.mark.asyncio + async def test_no_cost_dollars_skips_merge(self): + """When cost_dollars is absent, merge_stats is not called.""" + from backend.blocks.exa.search import ExaSearchBlock + + block = ExaSearchBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + mock_sdk_response = MagicMock() + mock_sdk_response.results = [] + mock_sdk_response.context = None + mock_sdk_response.resolved_search_type = None + mock_sdk_response.cost_dollars = None + + with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls: + mock_exa = MagicMock() + mock_exa.search = AsyncMock(return_value=mock_sdk_response) + mock_exa_cls.return_value = mock_exa + + async for _ in block.run( + block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + credentials=TEST_CREDENTIALS, + ): + pass + + assert len(merged) == 0 + + +class TestExaSimilarCostTracking: + """ExaFindSimilarBlock merges cost_dollars.total as provider_cost.""" + + @pytest.mark.asyncio + async def test_cost_dollars_total_is_merged(self): + """When the SDK response includes cost_dollars, its total is merged.""" + from backend.blocks.exa.helpers import CostDollars + from backend.blocks.exa.similar import ExaFindSimilarBlock + + block = ExaFindSimilarBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + mock_sdk_response = MagicMock() + mock_sdk_response.results = [] + mock_sdk_response.context = None + mock_sdk_response.request_id = "req-1" + mock_sdk_response.cost_dollars = CostDollars(total=0.015) + + with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls: + mock_exa = MagicMock() + mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response) + mock_exa_cls.return_value = mock_exa + + async for _ in block.run( + block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + credentials=TEST_CREDENTIALS, + ): + pass + + assert len(merged) == 1 + assert merged[0].provider_cost == pytest.approx(0.015) + + @pytest.mark.asyncio + async def test_no_cost_dollars_skips_merge(self): + """When cost_dollars is absent, merge_stats is not called.""" + from backend.blocks.exa.similar import ExaFindSimilarBlock + + block = ExaFindSimilarBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + mock_sdk_response = MagicMock() + mock_sdk_response.results = [] + mock_sdk_response.context = None + mock_sdk_response.request_id = "req-2" + mock_sdk_response.cost_dollars = None + + with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls: + mock_exa = MagicMock() + mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response) + mock_exa_cls.return_value = mock_exa + + async for _ in block.run( + block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + credentials=TEST_CREDENTIALS, + ): + pass + + assert len(merged) == 0 + + +# --------------------------------------------------------------------------- +# ExaCreateResearchBlock — cost_dollars from completed poll response +# --------------------------------------------------------------------------- + + +COMPLETED_RESEARCH_RESPONSE = { + "researchId": "test-research-id", + "status": "completed", + "model": "exa-research", + "instructions": "test instructions", + "createdAt": 1700000000000, + "finishedAt": 1700000060000, + "costDollars": { + "total": 0.05, + "numSearches": 3, + "numPages": 10, + "reasoningTokens": 500, + }, + "output": {"content": "Research findings...", "parsed": None}, +} + +PENDING_RESEARCH_RESPONSE = { + "researchId": "test-research-id", + "status": "pending", + "model": "exa-research", + "instructions": "test instructions", + "createdAt": 1700000000000, +} + + +class TestExaCreateResearchBlockCostTracking: + """ExaCreateResearchBlock merges cost from completed poll response.""" + + @pytest.mark.asyncio + async def test_cost_merged_when_research_completes(self): + """merge_stats called with provider_cost=total when poll returns completed.""" + from backend.blocks.exa.research import ExaCreateResearchBlock + + block = ExaCreateResearchBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + create_resp = MagicMock() + create_resp.json.return_value = PENDING_RESEARCH_RESPONSE + + poll_resp = MagicMock() + poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE + + mock_instance = MagicMock() + mock_instance.post = AsyncMock(return_value=create_resp) + mock_instance.get = AsyncMock(return_value=poll_resp) + + with ( + patch("backend.blocks.exa.research.Requests", return_value=mock_instance), + patch("asyncio.sleep", new=AsyncMock()), + ): + async for _ in block.run( + block.Input( + instructions="test instructions", + wait_for_completion=True, + credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type] + ), + credentials=TEST_CREDENTIALS, + ): + pass + + assert len(merged) == 1 + assert merged[0].provider_cost == pytest.approx(0.05) + + @pytest.mark.asyncio + async def test_no_merge_when_no_cost_dollars(self): + """When completed response has no costDollars, merge_stats is not called.""" + from backend.blocks.exa.research import ExaCreateResearchBlock + + block = ExaCreateResearchBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None} + create_resp = MagicMock() + create_resp.json.return_value = PENDING_RESEARCH_RESPONSE + poll_resp = MagicMock() + poll_resp.json.return_value = no_cost_response + + mock_instance = MagicMock() + mock_instance.post = AsyncMock(return_value=create_resp) + mock_instance.get = AsyncMock(return_value=poll_resp) + + with ( + patch("backend.blocks.exa.research.Requests", return_value=mock_instance), + patch("asyncio.sleep", new=AsyncMock()), + ): + async for _ in block.run( + block.Input( + instructions="test instructions", + wait_for_completion=True, + credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type] + ), + credentials=TEST_CREDENTIALS, + ): + pass + + assert merged == [] + + +# --------------------------------------------------------------------------- +# ExaGetResearchBlock — cost_dollars from single GET response +# --------------------------------------------------------------------------- + + +class TestExaGetResearchBlockCostTracking: + """ExaGetResearchBlock merges cost when the fetched research has cost_dollars.""" + + @pytest.mark.asyncio + async def test_cost_merged_from_completed_research(self): + """merge_stats called with provider_cost=total when research has costDollars.""" + from backend.blocks.exa.research import ExaGetResearchBlock + + block = ExaGetResearchBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + get_resp = MagicMock() + get_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE + + mock_instance = MagicMock() + mock_instance.get = AsyncMock(return_value=get_resp) + + with patch("backend.blocks.exa.research.Requests", return_value=mock_instance): + async for _ in block.run( + block.Input( + research_id="test-research-id", + credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type] + ), + credentials=TEST_CREDENTIALS, + ): + pass + + assert len(merged) == 1 + assert merged[0].provider_cost == pytest.approx(0.05) + + @pytest.mark.asyncio + async def test_no_merge_when_no_cost_dollars(self): + """When research has no costDollars, merge_stats is not called.""" + from backend.blocks.exa.research import ExaGetResearchBlock + + block = ExaGetResearchBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None} + get_resp = MagicMock() + get_resp.json.return_value = no_cost_response + + mock_instance = MagicMock() + mock_instance.get = AsyncMock(return_value=get_resp) + + with patch("backend.blocks.exa.research.Requests", return_value=mock_instance): + async for _ in block.run( + block.Input( + research_id="test-research-id", + credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type] + ), + credentials=TEST_CREDENTIALS, + ): + pass + + assert merged == [] + + +# --------------------------------------------------------------------------- +# ExaWaitForResearchBlock — cost_dollars from polling response +# --------------------------------------------------------------------------- + + +class TestExaWaitForResearchBlockCostTracking: + """ExaWaitForResearchBlock merges cost when the polled research has cost_dollars.""" + + @pytest.mark.asyncio + async def test_cost_merged_when_research_completes(self): + """merge_stats called with provider_cost=total once polling returns completed.""" + from backend.blocks.exa.research import ExaWaitForResearchBlock + + block = ExaWaitForResearchBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + poll_resp = MagicMock() + poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE + + mock_instance = MagicMock() + mock_instance.get = AsyncMock(return_value=poll_resp) + + with ( + patch("backend.blocks.exa.research.Requests", return_value=mock_instance), + patch("asyncio.sleep", new=AsyncMock()), + ): + async for _ in block.run( + block.Input( + research_id="test-research-id", + credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type] + ), + credentials=TEST_CREDENTIALS, + ): + pass + + assert len(merged) == 1 + assert merged[0].provider_cost == pytest.approx(0.05) + + @pytest.mark.asyncio + async def test_no_merge_when_no_cost_dollars(self): + """When completed research has no costDollars, merge_stats is not called.""" + from backend.blocks.exa.research import ExaWaitForResearchBlock + + block = ExaWaitForResearchBlock() + merged: list[NodeExecutionStats] = [] + block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment] + + no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None} + poll_resp = MagicMock() + poll_resp.json.return_value = no_cost_response + + mock_instance = MagicMock() + mock_instance.get = AsyncMock(return_value=poll_resp) + + with ( + patch("backend.blocks.exa.research.Requests", return_value=mock_instance), + patch("asyncio.sleep", new=AsyncMock()), + ): + async for _ in block.run( + block.Input( + research_id="test-research-id", + credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type] + ), + credentials=TEST_CREDENTIALS, + ): + pass + + assert merged == [] diff --git a/autogpt_platform/backend/backend/blocks/exa/research.py b/autogpt_platform/backend/backend/blocks/exa/research.py index c35a1048df..575a88cc01 100644 --- a/autogpt_platform/backend/backend/blocks/exa/research.py +++ b/autogpt_platform/backend/backend/blocks/exa/research.py @@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel +from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -232,6 +233,11 @@ class ExaCreateResearchBlock(Block): if research.cost_dollars: yield "cost_total", research.cost_dollars.total + self.merge_stats( + NodeExecutionStats( + provider_cost=research.cost_dollars.total + ) + ) return await asyncio.sleep(check_interval) @@ -346,6 +352,9 @@ class ExaGetResearchBlock(Block): yield "cost_searches", research.cost_dollars.num_searches yield "cost_pages", research.cost_dollars.num_pages yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens + self.merge_stats( + NodeExecutionStats(provider_cost=research.cost_dollars.total) + ) yield "error_message", research.error @@ -432,6 +441,9 @@ class ExaWaitForResearchBlock(Block): if research.cost_dollars: yield "cost_total", research.cost_dollars.total + self.merge_stats( + NodeExecutionStats(provider_cost=research.cost_dollars.total) + ) return diff --git a/autogpt_platform/backend/backend/blocks/exa/search.py b/autogpt_platform/backend/backend/blocks/exa/search.py index 7e4ccfc538..5d9e99698f 100644 --- a/autogpt_platform/backend/backend/blocks/exa/search.py +++ b/autogpt_platform/backend/backend/blocks/exa/search.py @@ -4,6 +4,7 @@ from typing import Optional from exa_py import AsyncExa +from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -206,3 +207,6 @@ class ExaSearchBlock(Block): if response.cost_dollars: yield "cost_dollars", response.cost_dollars + self.merge_stats( + NodeExecutionStats(provider_cost=response.cost_dollars.total) + ) diff --git a/autogpt_platform/backend/backend/blocks/exa/similar.py b/autogpt_platform/backend/backend/blocks/exa/similar.py index e2c592ff05..004dfec4d6 100644 --- a/autogpt_platform/backend/backend/blocks/exa/similar.py +++ b/autogpt_platform/backend/backend/blocks/exa/similar.py @@ -3,6 +3,7 @@ from typing import Optional from exa_py import AsyncExa +from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -167,3 +168,6 @@ class ExaFindSimilarBlock(Block): if response.cost_dollars: yield "cost_dollars", response.cost_dollars + self.merge_stats( + NodeExecutionStats(provider_cost=response.cost_dollars.total) + ) diff --git a/autogpt_platform/backend/backend/blocks/google_maps.py b/autogpt_platform/backend/backend/blocks/google_maps.py index bab0841c5d..8b561d3bd1 100644 --- a/autogpt_platform/backend/backend/blocks/google_maps.py +++ b/autogpt_platform/backend/backend/blocks/google_maps.py @@ -14,6 +14,7 @@ from backend.data.model import ( APIKeyCredentials, CredentialsField, CredentialsMetaInput, + NodeExecutionStats, SchemaField, ) from backend.integrations.providers import ProviderName @@ -117,6 +118,11 @@ class GoogleMapsSearchBlock(Block): input_data.radius, input_data.max_results, ) + self.merge_stats( + NodeExecutionStats( + provider_cost=float(len(places)), provider_cost_type="items" + ) + ) for place in places: yield "place", place diff --git a/autogpt_platform/backend/backend/blocks/jina/embeddings.py b/autogpt_platform/backend/backend/blocks/jina/embeddings.py index f787de03b3..88f97f43fb 100644 --- a/autogpt_platform/backend/backend/blocks/jina/embeddings.py +++ b/autogpt_platform/backend/backend/blocks/jina/embeddings.py @@ -10,7 +10,7 @@ from backend.blocks.jina._auth import ( JinaCredentialsField, JinaCredentialsInput, ) -from backend.data.model import SchemaField +from backend.data.model import NodeExecutionStats, SchemaField from backend.util.request import Requests @@ -45,5 +45,13 @@ class JinaEmbeddingBlock(Block): } data = {"input": input_data.texts, "model": input_data.model} response = await Requests().post(url, headers=headers, json=data) - embeddings = [e["embedding"] for e in response.json()["data"]] + resp_json = response.json() + embeddings = [e["embedding"] for e in resp_json["data"]] + usage = resp_json.get("usage", {}) + if usage.get("total_tokens"): + self.merge_stats( + NodeExecutionStats( + input_token_count=usage.get("total_tokens", 0), + ) + ) yield "embeddings", embeddings diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py index e3e34c9968..8543a03b69 100644 --- a/autogpt_platform/backend/backend/blocks/llm.py +++ b/autogpt_platform/backend/backend/blocks/llm.py @@ -1,6 +1,7 @@ # This file contains a lot of prompt block strings that would trigger "line too long" # flake8: noqa: E501 import logging +import math import re import secrets from abc import ABC @@ -13,6 +14,7 @@ import ollama import openai from anthropic.types import ToolParam from groq import AsyncGroq +from openai.types.chat import ChatCompletion as OpenAIChatCompletion from pydantic import BaseModel, SecretStr from backend.blocks._base import ( @@ -104,7 +106,6 @@ class LlmModelMeta(EnumMeta): class LlmModel(str, Enum, metaclass=LlmModelMeta): - @classmethod def _missing_(cls, value: object) -> "LlmModel | None": """Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'.""" @@ -201,10 +202,25 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta): GROK_4 = "x-ai/grok-4" GROK_4_FAST = "x-ai/grok-4-fast" GROK_4_1_FAST = "x-ai/grok-4.1-fast" + GROK_4_20 = "x-ai/grok-4.20" + GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent" GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1" KIMI_K2 = "moonshotai/kimi-k2" QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507" QWEN3_CODER = "qwen/qwen3-coder" + # Z.ai (Zhipu) models + ZAI_GLM_4_32B = "z-ai/glm-4-32b" + ZAI_GLM_4_5 = "z-ai/glm-4.5" + ZAI_GLM_4_5_AIR = "z-ai/glm-4.5-air" + ZAI_GLM_4_5_AIR_FREE = "z-ai/glm-4.5-air:free" + ZAI_GLM_4_5V = "z-ai/glm-4.5v" + ZAI_GLM_4_6 = "z-ai/glm-4.6" + ZAI_GLM_4_6V = "z-ai/glm-4.6v" + ZAI_GLM_4_7 = "z-ai/glm-4.7" + ZAI_GLM_4_7_FLASH = "z-ai/glm-4.7-flash" + ZAI_GLM_5 = "z-ai/glm-5" + ZAI_GLM_5_TURBO = "z-ai/glm-5-turbo" + ZAI_GLM_5V_TURBO = "z-ai/glm-5v-turbo" # Llama API models LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8" LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8" @@ -612,6 +628,18 @@ MODEL_METADATA = { LlmModel.GROK_4_1_FAST: ModelMetadata( "open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1 ), + LlmModel.GROK_4_20: ModelMetadata( + "open_router", 2000000, 100000, "Grok 4.20", "OpenRouter", "xAI", 3 + ), + LlmModel.GROK_4_20_MULTI_AGENT: ModelMetadata( + "open_router", + 2000000, + 100000, + "Grok 4.20 Multi-Agent", + "OpenRouter", + "xAI", + 3, + ), LlmModel.GROK_CODE_FAST_1: ModelMetadata( "open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1 ), @@ -630,6 +658,43 @@ MODEL_METADATA = { LlmModel.QWEN3_CODER: ModelMetadata( "open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3 ), + # https://openrouter.ai/models?q=z-ai + LlmModel.ZAI_GLM_4_32B: ModelMetadata( + "open_router", 128000, 128000, "GLM 4 32B", "OpenRouter", "Z.ai", 1 + ), + LlmModel.ZAI_GLM_4_5: ModelMetadata( + "open_router", 131072, 98304, "GLM 4.5", "OpenRouter", "Z.ai", 2 + ), + LlmModel.ZAI_GLM_4_5_AIR: ModelMetadata( + "open_router", 131072, 98304, "GLM 4.5 Air", "OpenRouter", "Z.ai", 1 + ), + LlmModel.ZAI_GLM_4_5_AIR_FREE: ModelMetadata( + "open_router", 131072, 96000, "GLM 4.5 Air (Free)", "OpenRouter", "Z.ai", 1 + ), + LlmModel.ZAI_GLM_4_5V: ModelMetadata( + "open_router", 65536, 16384, "GLM 4.5V", "OpenRouter", "Z.ai", 2 + ), + LlmModel.ZAI_GLM_4_6: ModelMetadata( + "open_router", 204800, 204800, "GLM 4.6", "OpenRouter", "Z.ai", 1 + ), + LlmModel.ZAI_GLM_4_6V: ModelMetadata( + "open_router", 131072, 131072, "GLM 4.6V", "OpenRouter", "Z.ai", 1 + ), + LlmModel.ZAI_GLM_4_7: ModelMetadata( + "open_router", 202752, 65535, "GLM 4.7", "OpenRouter", "Z.ai", 1 + ), + LlmModel.ZAI_GLM_4_7_FLASH: ModelMetadata( + "open_router", 202752, 202752, "GLM 4.7 Flash", "OpenRouter", "Z.ai", 1 + ), + LlmModel.ZAI_GLM_5: ModelMetadata( + "open_router", 80000, 80000, "GLM 5", "OpenRouter", "Z.ai", 2 + ), + LlmModel.ZAI_GLM_5_TURBO: ModelMetadata( + "open_router", 202752, 131072, "GLM 5 Turbo", "OpenRouter", "Z.ai", 3 + ), + LlmModel.ZAI_GLM_5V_TURBO: ModelMetadata( + "open_router", 202752, 131072, "GLM 5V Turbo", "OpenRouter", "Z.ai", 3 + ), # Llama API models LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata( "llama_api", @@ -686,17 +751,20 @@ class LLMResponse(BaseModel): tool_calls: Optional[List[ToolContentBlock]] | None prompt_tokens: int completion_tokens: int + cache_read_tokens: int = 0 + cache_creation_tokens: int = 0 reasoning: Optional[str] = None + provider_cost: float | None = None def convert_openai_tool_fmt_to_anthropic( openai_tools: list[dict] | None = None, -) -> Iterable[ToolParam] | anthropic.Omit: +) -> Iterable[ToolParam] | anthropic.NotGiven: """ Convert OpenAI tool format to Anthropic tool format. """ if not openai_tools or len(openai_tools) == 0: - return anthropic.omit + return anthropic.NOT_GIVEN anthropic_tools = [] for tool in openai_tools: @@ -721,6 +789,35 @@ def convert_openai_tool_fmt_to_anthropic( return anthropic_tools +def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None: + """Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response. + + OpenRouter returns the per-request USD cost in a response header. The + OpenAI SDK exposes the raw httpx response via an undocumented `_response` + attribute. We use try/except AttributeError so that if the SDK ever drops + or renames that attribute, the warning is visible in logs rather than + silently degrading to no cost tracking. + """ + try: + raw_resp = response._response # type: ignore[attr-defined] + except AttributeError: + logger.warning( + "OpenAI SDK response missing _response attribute" + " — OpenRouter cost tracking unavailable" + ) + return None + try: + cost_header = raw_resp.headers.get("x-total-cost") + if not cost_header: + return None + cost = float(cost_header) + if not math.isfinite(cost) or cost < 0: + return None + return cost + except (ValueError, TypeError, AttributeError): + return None + + def extract_openai_reasoning(response) -> str | None: """Extract reasoning from OpenAI-compatible response if available.""" """Note: This will likely not working since the reasoning is not present in another Response API""" @@ -803,6 +900,21 @@ async def llm_call( provider = llm_model.metadata.provider context_window = llm_model.context_window + # Transparent OpenRouter routing for Anthropic models: when an OpenRouter API key + # is configured, route direct-Anthropic models through OpenRouter instead. This + # gives us the x-total-cost header for free, so provider_cost is always populated + # without manual token-rate arithmetic. + or_key = settings.secrets.open_router_api_key + or_model_id: str | None = None + if provider == "anthropic" and or_key: + provider = "open_router" + credentials = APIKeyCredentials( + provider=ProviderName.OPEN_ROUTER, + title="OpenRouter (auto)", + api_key=SecretStr(or_key), + ) + or_model_id = f"anthropic/{llm_model.value}" + if compress_prompt_to_fit: result = await compress_context( messages=prompt, @@ -888,8 +1000,12 @@ async def llm_call( reasoning=reasoning, ) elif provider == "anthropic": - an_tools = convert_openai_tool_fmt_to_anthropic(tools) + # Cache tool definitions alongside the system prompt. + # Placing cache_control on the last tool caches all tool schemas as a + # single prefix — reads cost 10% of normal input tokens. + if isinstance(an_tools, list) and an_tools: + an_tools[-1] = {**an_tools[-1], "cache_control": {"type": "ephemeral"}} system_messages = [p["content"] for p in prompt if p["role"] == "system"] sysprompt = " ".join(system_messages) @@ -912,14 +1028,34 @@ async def llm_call( client = anthropic.AsyncAnthropic( api_key=credentials.api_key.get_secret_value() ) - resp = await client.messages.create( + # create_kwargs is built as a plain dict so we can conditionally add + # the `system` field only when the prompt is non-empty. Anthropic's + # API rejects empty text blocks (returns HTTP 400), so omitting the + # field is the correct behaviour for whitespace-only prompts. + create_kwargs: dict[str, Any] = dict( model=llm_model.value, - system=sysprompt, messages=messages, max_tokens=max_tokens, + # `an_tools` may be anthropic.NOT_GIVEN when no tools were + # configured. The SDK treats NOT_GIVEN as a sentinel meaning "omit + # this field from the serialized request", so passing it here is + # equivalent to not including the key at all — no `tools` field is + # sent to the API in that case. tools=an_tools, timeout=600, ) + if sysprompt.strip(): + # Wrap the system prompt in a single cacheable text block. + # The guard intentionally omits `system` for whitespace-only + # prompts — Anthropic rejects empty text blocks with HTTP 400. + create_kwargs["system"] = [ + { + "type": "text", + "text": sysprompt, + "cache_control": {"type": "ephemeral"}, + } + ] + resp = await client.messages.create(**create_kwargs) if not resp.content: raise ValueError("No content returned from Anthropic.") @@ -964,6 +1100,11 @@ async def llm_call( tool_calls=tool_calls, prompt_tokens=resp.usage.input_tokens, completion_tokens=resp.usage.output_tokens, + cache_read_tokens=getattr(resp.usage, "cache_read_input_tokens", None) or 0, + cache_creation_tokens=getattr( + resp.usage, "cache_creation_input_tokens", None + ) + or 0, reasoning=reasoning, ) elif provider == "groq": @@ -1032,7 +1173,7 @@ async def llm_call( "HTTP-Referer": "https://agpt.co", "X-Title": "AutoGPT", }, - model=llm_model.value, + model=or_model_id or llm_model.value, messages=prompt, # type: ignore max_tokens=max_tokens, tools=tools_param, # type: ignore @@ -1053,6 +1194,7 @@ async def llm_call( prompt_tokens=response.usage.prompt_tokens if response.usage else 0, completion_tokens=response.usage.completion_tokens if response.usage else 0, reasoning=reasoning, + provider_cost=extract_openrouter_cost(response), ) elif provider == "llama_api": tools_param = tools if tools else openai.NOT_GIVEN @@ -1360,6 +1502,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): error_feedback_message = "" llm_model = input_data.model + total_provider_cost: float | None = None for retry_count in range(input_data.retry): logger.debug(f"LLM request: {prompt}") @@ -1377,12 +1520,19 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): max_tokens=input_data.max_tokens, ) response_text = llm_response.response - self.merge_stats( - NodeExecutionStats( - input_token_count=llm_response.prompt_tokens, - output_token_count=llm_response.completion_tokens, - ) + # Accumulate token counts and provider_cost for every attempt + # (each call costs tokens and USD, regardless of validation outcome). + token_stats = NodeExecutionStats( + input_token_count=llm_response.prompt_tokens, + output_token_count=llm_response.completion_tokens, + cache_read_token_count=llm_response.cache_read_tokens, + cache_creation_token_count=llm_response.cache_creation_tokens, ) + self.merge_stats(token_stats) + if llm_response.provider_cost is not None: + total_provider_cost = ( + total_provider_cost or 0.0 + ) + llm_response.provider_cost logger.debug(f"LLM attempt-{retry_count} response: {response_text}") if input_data.expected_format: @@ -1451,6 +1601,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): NodeExecutionStats( llm_call_count=retry_count + 1, llm_retry_count=retry_count, + provider_cost=total_provider_cost, ) ) yield "response", response_obj @@ -1471,6 +1622,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): NodeExecutionStats( llm_call_count=retry_count + 1, llm_retry_count=retry_count, + provider_cost=total_provider_cost, ) ) yield "response", {"response": response_text} @@ -1502,6 +1654,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): error_feedback_message = f"Error calling LLM: {e}" + # All retries exhausted or user-error break: persist accumulated cost so + # the executor can still charge/report the spend even on failure. + if total_provider_cost is not None: + self.merge_stats(NodeExecutionStats(provider_cost=total_provider_cost)) raise RuntimeError(error_feedback_message) def response_format_instructions( diff --git a/autogpt_platform/backend/backend/blocks/orchestrator.py b/autogpt_platform/backend/backend/blocks/orchestrator.py index 8f9037cc72..b2a6df8481 100644 --- a/autogpt_platform/backend/backend/blocks/orchestrator.py +++ b/autogpt_platform/backend/backend/blocks/orchestrator.py @@ -36,6 +36,7 @@ from backend.data.execution import ExecutionContext from backend.data.model import NodeExecutionStats, SchemaField from backend.util import json from backend.util.clients import get_database_manager_async_client +from backend.util.exceptions import InsufficientBalanceError from backend.util.prompt import MAIN_OBJECTIVE_PREFIX from backend.util.security import SENSITIVE_FIELD_NAMES from backend.util.tool_call_loop import ( @@ -251,8 +252,13 @@ def _convert_raw_response_to_dict( # Already a dict (from tests or some providers) return raw_response elif _is_responses_api_object(raw_response): - # OpenAI Responses API: extract individual output items - items = [json.to_dict(item) for item in raw_response.output] + # OpenAI Responses API: extract individual output items. + # Strip 'status' — it's a response-only field that OpenAI rejects + # when the item is sent back as input on the next API call. + items = [ + {k: v for k, v in json.to_dict(item).items() if k != "status"} + for item in raw_response.output + ] return items if items else [{"role": "assistant", "content": ""}] else: # Chat Completions / Anthropic return message objects @@ -359,10 +365,31 @@ def _disambiguate_tool_names(tools: list[dict[str, Any]]) -> None: class OrchestratorBlock(Block): + """A block that uses a language model to orchestrate tool calls. + + Supports both single-shot and iterative agent mode execution. + + **InsufficientBalanceError propagation contract**: ``InsufficientBalanceError`` + (IBE) must always re-raise through every ``except`` block in this class. + Swallowing IBE would let the agent loop continue with unpaid work. Every + exception handler that catches ``Exception`` includes an explicit IBE + re-raise carve-out for this reason. """ - A block that uses a language model to orchestrate tool calls, supporting both - single-shot and iterative agent mode execution. - """ + + def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int: + """Charge one extra runtime cost per LLM call beyond the first. + + In agent mode each iteration makes one LLM call. The first is already + covered by charge_usage(); this returns the number of additional + credits so the executor can bill the remaining calls post-completion. + + SDK-mode exemption: when the block runs via _execute_tools_sdk_mode, + the SDK manages its own conversation loop and only exposes aggregate + usage. We hardcode llm_call_count=1 there (the SDK does not report a + per-turn call count), so this method always returns 0 for SDK-mode + executions. Per-iteration billing does not apply to SDK mode. + """ + return max(0, execution_stats.llm_call_count - 1) # MCP server name used by the Claude Code SDK execution mode. Keep in sync # with _create_graph_mcp_server and the MCP_PREFIX derivation in _execute_tools_sdk_mode. @@ -844,7 +871,10 @@ class OrchestratorBlock(Block): NodeExecutionStats( input_token_count=resp.prompt_tokens, output_token_count=resp.completion_tokens, + cache_read_token_count=resp.cache_read_tokens, + cache_creation_token_count=resp.cache_creation_tokens, llm_call_count=1, + provider_cost=resp.provider_cost, ) ) @@ -1069,7 +1099,10 @@ class OrchestratorBlock(Block): input_data=input_value, ) - assert node_exec_result is not None, "node_exec_result should not be None" + if node_exec_result is None: + raise RuntimeError( + f"upsert_execution_input returned None for node {sink_node_id}" + ) # Create NodeExecutionEntry for execution manager node_exec_entry = NodeExecutionEntry( @@ -1104,15 +1137,86 @@ class OrchestratorBlock(Block): task=node_exec_future, ) - # Execute the node directly since we're in the Orchestrator context - node_exec_future.set_result( - await execution_processor.on_node_execution( + # Execute the node directly since we're in the Orchestrator context. + # Wrap in try/except so the future is always resolved, even on + # error — an unresolved Future would block anything awaiting it. + # + # on_node_execution is decorated with @async_error_logged(swallow=True), + # which catches BaseException and returns None rather than raising. + # Treat a None return as a failure: set_exception so the future + # carries an error state rather than a None result, and return an + # error response so the LLM knows the tool failed. + try: + tool_node_stats = await execution_processor.on_node_execution( node_exec=node_exec_entry, node_exec_progress=node_exec_progress, nodes_input_masks=None, graph_stats_pair=graph_stats_pair, ) - ) + if tool_node_stats is None: + nil_err = RuntimeError( + f"on_node_execution returned None for node {sink_node_id} " + "(error was swallowed by @async_error_logged)" + ) + node_exec_future.set_exception(nil_err) + resp = _create_tool_response( + tool_call.id, + "Tool execution returned no result", + responses_api=responses_api, + ) + resp["_is_error"] = True + return resp + node_exec_future.set_result(tool_node_stats) + except Exception as exec_err: + node_exec_future.set_exception(exec_err) + raise + + # Charge user credits AFTER successful tool execution. Tools + # spawned by the orchestrator bypass the main execution queue + # (where _charge_usage is called), so we must charge here to + # avoid free tool execution. Charging post-completion (vs. + # pre-execution) avoids billing users for failed tool calls. + # Skipped for dry runs. + # + # `error is None` intentionally excludes both Exception and + # BaseException subclasses (e.g. CancelledError) so cancelled + # or terminated tool runs are not billed. + # + # Billing errors (including non-balance exceptions) are kept + # in a separate try/except so they are never silently swallowed + # by the generic tool-error handler below. + if ( + not execution_params.execution_context.dry_run + and tool_node_stats.error is None + ): + try: + tool_cost, _ = await execution_processor.charge_node_usage( + node_exec_entry, + ) + except InsufficientBalanceError: + # IBE must propagate — see OrchestratorBlock class docstring. + # Log the billing failure here so the discarded tool result + # is traceable before the loop aborts. + logger.warning( + "Insufficient balance charging for tool node %s after " + "successful execution; agent loop will be aborted", + sink_node_id, + ) + raise + except Exception: + # Non-billing charge failures (DB outage, network, etc.) + # must NOT propagate to the outer except handler because + # the tool itself succeeded. Re-raising would mark the + # tool as failed (_is_error=True), causing the LLM to + # retry side-effectful operations. Log and continue. + logger.exception( + "Unexpected error charging for tool node %s; " + "tool execution was successful", + sink_node_id, + ) + tool_cost = 0 + if tool_cost > 0: + self.merge_stats(NodeExecutionStats(extra_cost=tool_cost)) # Get outputs from database after execution completes using database manager client node_outputs = await db_client.get_execution_outputs_by_node_exec_id( @@ -1125,18 +1229,26 @@ class OrchestratorBlock(Block): if node_outputs else "Tool executed successfully" ) - return _create_tool_response( + resp = _create_tool_response( tool_call.id, tool_response_content, responses_api=responses_api ) + resp["_is_error"] = False + return resp + except InsufficientBalanceError: + # IBE must propagate — see class docstring. + raise except Exception as e: - logger.warning("Tool execution with manager failed: %s", e) - # Return error response - return _create_tool_response( + logger.warning("Tool execution with manager failed: %s", e, exc_info=True) + # Return a generic error to the LLM — internal exception messages + # may contain server paths, DB details, or infrastructure info. + resp = _create_tool_response( tool_call.id, - f"Tool execution failed: {e}", + "Tool execution failed due to an internal error", responses_api=responses_api, ) + resp["_is_error"] = True + return resp async def _agent_mode_llm_caller( self, @@ -1236,13 +1348,16 @@ class OrchestratorBlock(Block): content = str(raw_content) else: content = "Tool executed successfully" - tool_failed = content.startswith("Tool execution failed:") + tool_failed = result.get("_is_error", True) return ToolCallResult( tool_call_id=tool_call.id, tool_name=tool_call.name, content=content, is_error=tool_failed, ) + except InsufficientBalanceError: + # IBE must propagate — see class docstring. + raise except Exception as e: logger.error("Tool execution failed: %s", e) return ToolCallResult( @@ -1362,9 +1477,13 @@ class OrchestratorBlock(Block): "arguments": tc.arguments, }, ) + except InsufficientBalanceError: + # IBE must propagate — see class docstring. + raise except Exception as e: - # Catch all errors (validation, network, API) so that the block - # surfaces them as user-visible output instead of crashing. + # Catch all OTHER errors (validation, network, API) so that + # the block surfaces them as user-visible output instead of + # crashing. yield "error", str(e) return @@ -1442,11 +1561,14 @@ class OrchestratorBlock(Block): text = content else: text = json.dumps(content) - tool_failed = text.startswith("Tool execution failed:") + tool_failed = result.get("_is_error", True) return { "content": [{"type": "text", "text": text}], "isError": tool_failed, } + except InsufficientBalanceError: + # IBE must propagate — see class docstring. + raise except Exception as e: logger.error("SDK tool execution failed: %s", e) return { @@ -1572,6 +1694,7 @@ class OrchestratorBlock(Block): conversation: list[dict[str, Any]] = list(prompt) # Start with input prompt total_prompt_tokens = 0 total_completion_tokens = 0 + total_cost_usd: float | None = None sdk_error: Exception | None = None try: @@ -1715,6 +1838,8 @@ class OrchestratorBlock(Block): total_completion_tokens += getattr( sdk_msg.usage, "output_tokens", 0 ) + if sdk_msg.total_cost_usd is not None: + total_cost_usd = sdk_msg.total_cost_usd finally: if pending_task is not None and not pending_task.done(): pending_task.cancel() @@ -1722,11 +1847,15 @@ class OrchestratorBlock(Block): await pending_task except (asyncio.CancelledError, StopAsyncIteration): pass + except InsufficientBalanceError: + # IBE must propagate — see class docstring. The `finally` + # block below still runs and records partial token usage. + raise except Exception as e: - # Surface SDK errors as user-visible output instead of crashing, - # consistent with _execute_tools_agent_mode error handling. - # Don't return yet — fall through to merge_stats below so - # partial token usage is always recorded. + # Surface OTHER SDK errors as user-visible output instead + # of crashing, consistent with _execute_tools_agent_mode + # error handling. Don't return yet — fall through to + # merge_stats below so partial token usage is always recorded. sdk_error = e finally: # Always record usage stats, even on error. The SDK may have @@ -1734,12 +1863,17 @@ class OrchestratorBlock(Block): # those stats would under-count resource usage. # llm_call_count=1 is approximate; the SDK manages its own # multi-turn loop and only exposes aggregate usage. - if total_prompt_tokens > 0 or total_completion_tokens > 0: + if ( + total_prompt_tokens > 0 + or total_completion_tokens > 0 + or total_cost_usd is not None + ): self.merge_stats( NodeExecutionStats( input_token_count=total_prompt_tokens, output_token_count=total_completion_tokens, llm_call_count=1, + provider_cost=total_cost_usd, ) ) # Clean up execution-specific working directory. diff --git a/autogpt_platform/backend/backend/blocks/perplexity.py b/autogpt_platform/backend/backend/blocks/perplexity.py index a8b137ce2b..abdbadef91 100644 --- a/autogpt_platform/backend/backend/blocks/perplexity.py +++ b/autogpt_platform/backend/backend/blocks/perplexity.py @@ -98,14 +98,23 @@ class PerplexityBlock(Block): return _sanitize_perplexity_model(v) @classmethod - def validate_data(cls, data: BlockInput) -> str | None: + def validate_data( + cls, + data: BlockInput, + exclude_fields: set[str] | None = None, + ) -> str | None: """Sanitize the model field before JSON schema validation so that invalid values are replaced with the default instead of raising a - BlockInputError.""" + BlockInputError. + + Signature matches ``BlockSchema.validate_data`` (including the + optional ``exclude_fields`` kwarg added for dry-run credential + bypass) so Pyright doesn't flag this as an incompatible override. + """ model_value = data.get("model") if model_value is not None: data["model"] = _sanitize_perplexity_model(model_value).value - return super().validate_data(data) + return super().validate_data(data, exclude_fields=exclude_fields) system_prompt: str = SchemaField( title="System Prompt", diff --git a/autogpt_platform/backend/backend/blocks/smartlead/campaign.py b/autogpt_platform/backend/backend/blocks/smartlead/campaign.py index 302a38f4db..ce900a2d09 100644 --- a/autogpt_platform/backend/backend/blocks/smartlead/campaign.py +++ b/autogpt_platform/backend/backend/blocks/smartlead/campaign.py @@ -23,7 +23,7 @@ from backend.blocks.smartlead.models import ( SaveSequencesResponse, Sequence, ) -from backend.data.model import CredentialsField, SchemaField +from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField class CreateCampaignBlock(Block): @@ -226,6 +226,12 @@ class AddLeadToCampaignBlock(Block): response = await self.add_leads_to_campaign( input_data.campaign_id, input_data.lead_list, credentials ) + self.merge_stats( + NodeExecutionStats( + provider_cost=float(len(input_data.lead_list)), + provider_cost_type="items", + ) + ) yield "campaign_id", input_data.campaign_id yield "upload_count", response.upload_count diff --git a/autogpt_platform/backend/backend/blocks/test/test_autopilot.py b/autogpt_platform/backend/backend/blocks/test/test_autopilot.py index 2526bf1455..5fb468fb03 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_autopilot.py +++ b/autogpt_platform/backend/backend/blocks/test/test_autopilot.py @@ -1,13 +1,14 @@ """Tests for AutoPilotBlock: recursion guard, streaming, validation, and error paths.""" import asyncio -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch import pytest from backend.blocks.autopilot import ( AUTOPILOT_BLOCK_ID, AutoPilotBlock, + SubAgentRecursionError, _autopilot_recursion_depth, _autopilot_recursion_limit, _check_recursion, @@ -57,7 +58,7 @@ class TestCheckRecursion: try: t2 = _check_recursion(2) try: - with pytest.raises(RuntimeError, match="recursion depth limit"): + with pytest.raises(SubAgentRecursionError): _check_recursion(2) finally: _reset_recursion(t2) @@ -71,7 +72,7 @@ class TestCheckRecursion: t2 = _check_recursion(10) # inner wants 10, but inherited is 2 try: # depth is now 2, limit is min(10, 2) = 2 → should raise - with pytest.raises(RuntimeError, match="recursion depth limit"): + with pytest.raises(SubAgentRecursionError): _check_recursion(10) finally: _reset_recursion(t2) @@ -81,7 +82,7 @@ class TestCheckRecursion: def test_limit_of_one_blocks_immediately_on_second_call(self): t1 = _check_recursion(1) try: - with pytest.raises(RuntimeError): + with pytest.raises(SubAgentRecursionError): _check_recursion(1) finally: _reset_recursion(t1) @@ -175,6 +176,29 @@ class TestRunValidation: assert outputs["session_id"] == "sess-cancel" assert "cancelled" in outputs.get("error", "").lower() + @pytest.mark.asyncio + async def test_dry_run_inherited_from_execution_context(self, block): + """execution_context.dry_run=True must be OR-ed into create_session dry_run + so that nested AutoPilot sessions simulate even when input_data.dry_run=False. + """ + mock_result = ( + "ok", + [], + "[]", + "sess-dry", + {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + ) + block.execute_copilot = AsyncMock(return_value=mock_result) + block.create_session = AsyncMock(return_value="sess-dry") + + input_data = block.Input(prompt="test", max_recursion_depth=3, dry_run=False) + ctx = _make_context() + ctx.dry_run = True # outer execution is dry_run + async for _ in block.run(input_data, execution_context=ctx): + pass + + block.create_session.assert_called_once_with(ctx.user_id, dry_run=True) + @pytest.mark.asyncio async def test_existing_session_id_skips_create(self, block): """When session_id is provided, create_session should not be called.""" @@ -221,3 +245,171 @@ class TestBlockRegistration: # The field should exist (inherited) but there should be no explicit # redefinition. We verify by checking the class __annotations__ directly. assert "error" not in AutoPilotBlock.Output.__annotations__ + + +# --------------------------------------------------------------------------- +# Recovery enqueue integration tests +# --------------------------------------------------------------------------- + + +class TestRecoveryEnqueue: + """Tests that run() enqueues orphaned sessions for recovery on failure.""" + + @pytest.fixture + def block(self): + return AutoPilotBlock() + + @pytest.mark.asyncio + async def test_recovery_enqueued_on_transient_exception(self, block): + """A generic exception should trigger _enqueue_for_recovery.""" + block.execute_copilot = AsyncMock(side_effect=RuntimeError("network error")) + block.create_session = AsyncMock(return_value="sess-recover") + + input_data = block.Input(prompt="do work", max_recursion_depth=3) + ctx = _make_context() + + with patch("backend.blocks.autopilot._enqueue_for_recovery") as mock_enqueue: + mock_enqueue.return_value = None + outputs = {} + async for name, value in block.run(input_data, execution_context=ctx): + outputs[name] = value + + assert "network error" in outputs.get("error", "") + mock_enqueue.assert_awaited_once_with( + "sess-recover", + ctx.user_id, + "do work", + False, + ) + + @pytest.mark.asyncio + async def test_recovery_not_enqueued_for_recursion_limit(self, block): + """Recursion limit errors are deliberate — no recovery enqueue.""" + block.execute_copilot = AsyncMock( + side_effect=SubAgentRecursionError( + "AutoPilot recursion depth limit reached (3). " + "The autopilot has called itself too many times." + ) + ) + block.create_session = AsyncMock(return_value="sess-rec-limit") + + input_data = block.Input(prompt="recurse", max_recursion_depth=3) + ctx = _make_context() + + with patch("backend.blocks.autopilot._enqueue_for_recovery") as mock_enqueue: + async for _ in block.run(input_data, execution_context=ctx): + pass + + mock_enqueue.assert_not_awaited() + + @pytest.mark.asyncio + async def test_recovery_not_enqueued_for_dry_run(self, block): + """dry_run=True sessions must not be enqueued (no real consumers).""" + block.execute_copilot = AsyncMock(side_effect=RuntimeError("transient")) + block.create_session = AsyncMock(return_value="sess-dry-fail") + + input_data = block.Input(prompt="test", max_recursion_depth=3, dry_run=True) + ctx = _make_context() + + with patch("backend.blocks.autopilot._enqueue_for_recovery") as mock_enqueue: + mock_enqueue.return_value = None + async for _ in block.run(input_data, execution_context=ctx): + pass + + # _enqueue_for_recovery is called with dry_run=True, + # so the inner guard returns early without publishing to the queue. + mock_enqueue.assert_awaited_once() + positional = mock_enqueue.call_args_list[0][0] + assert positional[3] is True # dry_run=True + + @pytest.mark.asyncio + async def test_recovery_enqueue_failure_does_not_mask_original_error(self, block): + """If _enqueue_for_recovery itself raises, the original error is still yielded.""" + block.execute_copilot = AsyncMock(side_effect=ValueError("original")) + block.create_session = AsyncMock(return_value="sess-enq-fail") + + input_data = block.Input(prompt="hello", max_recursion_depth=3) + ctx = _make_context() + + async def _failing_enqueue(*args, **kwargs): + raise OSError("rabbitmq down") + + with patch( + "backend.blocks.autopilot._enqueue_for_recovery", + side_effect=_failing_enqueue, + ): + outputs = {} + async for name, value in block.run(input_data, execution_context=ctx): + outputs[name] = value + + # Original error must still be surfaced despite the enqueue failure + assert outputs.get("error") == "original" + assert outputs.get("session_id") == "sess-enq-fail" + + @pytest.mark.asyncio + async def test_recovery_uses_dry_run_from_context(self, block): + """execution_context.dry_run=True is OR-ed into the dry_run arg.""" + block.execute_copilot = AsyncMock(side_effect=RuntimeError("fail")) + block.create_session = AsyncMock(return_value="sess-ctx-dry") + + input_data = block.Input(prompt="test", max_recursion_depth=3, dry_run=False) + ctx = _make_context() + ctx.dry_run = True # outer execution is dry_run + + with patch("backend.blocks.autopilot._enqueue_for_recovery") as mock_enqueue: + mock_enqueue.return_value = None + async for _ in block.run(input_data, execution_context=ctx): + pass + + mock_enqueue.assert_awaited_once() + positional = mock_enqueue.call_args_list[0][0] + assert positional[3] is True # dry_run=True + + @pytest.mark.asyncio + async def test_recovery_uses_effective_prompt_with_system_context(self, block): + """When system_context is set, _enqueue_for_recovery receives the + effective_prompt (system_context prepended) so the dedup check in + maybe_append_user_message passes on replay.""" + block.execute_copilot = AsyncMock(side_effect=RuntimeError("e2b timeout")) + block.create_session = AsyncMock(return_value="sess-sys-ctx") + + input_data = block.Input( + prompt="do work", + system_context="Be concise.", + max_recursion_depth=3, + ) + ctx = _make_context() + + with patch("backend.blocks.autopilot._enqueue_for_recovery") as mock_enqueue: + mock_enqueue.return_value = None + async for _ in block.run(input_data, execution_context=ctx): + pass + + mock_enqueue.assert_awaited_once() + positional = mock_enqueue.call_args_list[0][0] + assert positional[2] == "[System Context: Be concise.]\n\ndo work" + + @pytest.mark.asyncio + async def test_recovery_cancelled_error_still_yields_error(self, block): + """CancelledError during _enqueue_for_recovery still yields the error output.""" + block.execute_copilot = AsyncMock(side_effect=RuntimeError("e2b stall")) + block.create_session = AsyncMock(return_value="sess-cancel") + + async def _cancelled_enqueue(*args, **kwargs): + raise asyncio.CancelledError + + outputs = {} + with patch( + "backend.blocks.autopilot._enqueue_for_recovery", + side_effect=_cancelled_enqueue, + ): + with pytest.raises(asyncio.CancelledError): + async for name, value in block.run( + block.Input(prompt="do work", max_recursion_depth=3), + execution_context=_make_context(), + ): + outputs[name] = value + + # error must be yielded even when recovery raises CancelledError + assert outputs.get("error") == "e2b stall" + assert outputs.get("session_id") == "sess-cancel" diff --git a/autogpt_platform/backend/backend/blocks/test/test_llm.py b/autogpt_platform/backend/backend/blocks/test/test_llm.py index 9471095fef..f7be1e100f 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_llm.py +++ b/autogpt_platform/backend/backend/blocks/test/test_llm.py @@ -46,6 +46,110 @@ class TestLLMStatsTracking: assert response.completion_tokens == 20 assert response.response == "Test response" + @pytest.mark.asyncio + async def test_llm_call_anthropic_returns_cache_tokens(self): + """Test that llm_call returns cache read/creation tokens from Anthropic.""" + from pydantic import SecretStr + + import backend.blocks.llm as llm + from backend.data.model import APIKeyCredentials + + anthropic_creds = APIKeyCredentials( + id="test-anthropic-id", + provider="anthropic", + api_key=SecretStr("mock-anthropic-key"), + title="Mock Anthropic key", + expires_at=None, + ) + + mock_content_block = MagicMock() + mock_content_block.type = "text" + mock_content_block.text = "Test anthropic response" + + mock_usage = MagicMock() + mock_usage.input_tokens = 15 + mock_usage.output_tokens = 25 + mock_usage.cache_read_input_tokens = 100 + mock_usage.cache_creation_input_tokens = 50 + + mock_response = MagicMock() + mock_response.content = [mock_content_block] + mock_response.usage = mock_usage + mock_response.stop_reason = "end_turn" + + with ( + patch("anthropic.AsyncAnthropic") as mock_anthropic, + patch("backend.blocks.llm.settings") as mock_settings, + ): + mock_settings.secrets.open_router_api_key = "" + mock_client = AsyncMock() + mock_anthropic.return_value = mock_client + mock_client.messages.create = AsyncMock(return_value=mock_response) + + response = await llm.llm_call( + credentials=anthropic_creds, + llm_model=llm.LlmModel.CLAUDE_3_HAIKU, + prompt=[{"role": "user", "content": "Hello"}], + max_tokens=100, + ) + + assert isinstance(response, llm.LLMResponse) + assert response.prompt_tokens == 15 + assert response.completion_tokens == 25 + assert response.cache_read_tokens == 100 + assert response.cache_creation_tokens == 50 + assert response.response == "Test anthropic response" + + @pytest.mark.asyncio + async def test_anthropic_routes_through_openrouter_when_key_present(self): + """When open_router_api_key is set, Anthropic models route via OpenRouter.""" + from pydantic import SecretStr + + import backend.blocks.llm as llm + from backend.data.model import APIKeyCredentials + + anthropic_creds = APIKeyCredentials( + id="test-anthropic-id", + provider="anthropic", + api_key=SecretStr("mock-anthropic-key"), + title="Mock Anthropic key", + ) + + mock_choice = MagicMock() + mock_choice.message.content = "routed response" + mock_choice.message.tool_calls = None + + mock_usage = MagicMock() + mock_usage.prompt_tokens = 10 + mock_usage.completion_tokens = 5 + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + mock_response.usage = mock_usage + + mock_create = AsyncMock(return_value=mock_response) + + with ( + patch("openai.AsyncOpenAI") as mock_openai, + patch("backend.blocks.llm.settings") as mock_settings, + ): + mock_settings.secrets.open_router_api_key = "sk-or-test-key" + mock_client = MagicMock() + mock_openai.return_value = mock_client + mock_client.chat.completions.create = mock_create + + await llm.llm_call( + credentials=anthropic_creds, + llm_model=llm.LlmModel.CLAUDE_3_HAIKU, + prompt=[{"role": "user", "content": "Hello"}], + max_tokens=100, + ) + + # Verify OpenAI client was used (not Anthropic SDK) and model was prefixed + mock_openai.assert_called_once() + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["model"] == "anthropic/claude-3-haiku-20240307" + @pytest.mark.asyncio async def test_ai_structured_response_block_tracks_stats(self): """Test that AIStructuredResponseGeneratorBlock correctly tracks stats.""" @@ -199,6 +303,139 @@ class TestLLMStatsTracking: assert block.execution_stats.llm_call_count == 2 # retry_count + 1 = 1 + 1 = 2 assert block.execution_stats.llm_retry_count == 1 + @pytest.mark.asyncio + async def test_retry_cost_accumulates_across_attempts(self): + """provider_cost accumulates across all retry attempts. + + Each LLM call incurs a real cost, including failed validation attempts. + The total cost is the sum of all attempts so no billed USD is lost. + """ + import backend.blocks.llm as llm + + block = llm.AIStructuredResponseGeneratorBlock() + call_count = 0 + + async def mock_llm_call(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First attempt: fails validation, returns cost $0.01 + return llm.LLMResponse( + raw_response="", + prompt=[], + response='{"wrong": "key"}', + tool_calls=None, + prompt_tokens=10, + completion_tokens=5, + reasoning=None, + provider_cost=0.01, + ) + # Second attempt: succeeds, returns cost $0.02 + return llm.LLMResponse( + raw_response="", + prompt=[], + response='{"key1": "value1", "key2": "value2"}', + tool_calls=None, + prompt_tokens=20, + completion_tokens=10, + reasoning=None, + provider_cost=0.02, + ) + + block.llm_call = mock_llm_call # type: ignore + + input_data = llm.AIStructuredResponseGeneratorBlock.Input( + prompt="Test prompt", + expected_format={"key1": "desc1", "key2": "desc2"}, + model=llm.DEFAULT_LLM_MODEL, + credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore + retry=2, + ) + + with patch("secrets.token_hex", return_value="test123456"): + async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS): + pass + + # provider_cost accumulates across all attempts: $0.01 + $0.02 = $0.03 + assert block.execution_stats.provider_cost == pytest.approx(0.03) + # Tokens from both attempts accumulate + assert block.execution_stats.input_token_count == 30 + assert block.execution_stats.output_token_count == 15 + + @pytest.mark.asyncio + async def test_cache_tokens_accumulated_in_stats(self): + """Cache read/creation tokens are tracked per-attempt and accumulated.""" + import backend.blocks.llm as llm + + block = llm.AIStructuredResponseGeneratorBlock() + + async def mock_llm_call(*args, **kwargs): + return llm.LLMResponse( + raw_response="", + prompt=[], + response='{"key1": "v1", "key2": "v2"}', + tool_calls=None, + prompt_tokens=10, + completion_tokens=5, + cache_read_tokens=20, + cache_creation_tokens=8, + reasoning=None, + provider_cost=0.005, + ) + + block.llm_call = mock_llm_call # type: ignore + + input_data = llm.AIStructuredResponseGeneratorBlock.Input( + prompt="Test prompt", + expected_format={"key1": "desc1", "key2": "desc2"}, + model=llm.DEFAULT_LLM_MODEL, + credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore + retry=1, + ) + + with patch("secrets.token_hex", return_value="tok123456"): + async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS): + pass + + assert block.execution_stats.cache_read_token_count == 20 + assert block.execution_stats.cache_creation_token_count == 8 + + @pytest.mark.asyncio + async def test_failure_path_persists_accumulated_cost(self): + """When all retries are exhausted, accumulated provider_cost is preserved.""" + import backend.blocks.llm as llm + + block = llm.AIStructuredResponseGeneratorBlock() + + async def mock_llm_call(*args, **kwargs): + return llm.LLMResponse( + raw_response="", + prompt=[], + response="not valid json at all", + tool_calls=None, + prompt_tokens=10, + completion_tokens=5, + reasoning=None, + provider_cost=0.01, + ) + + block.llm_call = mock_llm_call # type: ignore + + input_data = llm.AIStructuredResponseGeneratorBlock.Input( + prompt="Test prompt", + expected_format={"key1": "desc1"}, + model=llm.DEFAULT_LLM_MODEL, + credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore + retry=2, + ) + + with pytest.raises(RuntimeError): + async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS): + pass + + # Both retry attempts each cost $0.01, total $0.02 + assert block.execution_stats.provider_cost == pytest.approx(0.02) + @pytest.mark.asyncio async def test_ai_text_summarizer_multiple_chunks(self): """Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks.""" @@ -987,3 +1224,295 @@ class TestLlmModelMissing: assert ( llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO ) + + +class TestExtractOpenRouterCost: + """Tests for extract_openrouter_cost — the x-total-cost header parser.""" + + def _mk_response(self, headers: dict | None): + response = MagicMock() + if headers is None: + response._response = None + else: + raw = MagicMock() + raw.headers = headers + response._response = raw + return response + + def test_extracts_numeric_cost(self): + response = self._mk_response({"x-total-cost": "0.0042"}) + assert llm.extract_openrouter_cost(response) == 0.0042 + + def test_returns_none_when_header_missing(self): + response = self._mk_response({}) + assert llm.extract_openrouter_cost(response) is None + + def test_returns_none_when_header_empty_string(self): + response = self._mk_response({"x-total-cost": ""}) + assert llm.extract_openrouter_cost(response) is None + + def test_returns_none_when_header_non_numeric(self): + response = self._mk_response({"x-total-cost": "not-a-number"}) + assert llm.extract_openrouter_cost(response) is None + + def test_returns_none_when_no_response_attr(self): + response = MagicMock(spec=[]) # no _response attr + assert llm.extract_openrouter_cost(response) is None + + def test_returns_none_when_raw_is_none(self): + response = self._mk_response(None) + assert llm.extract_openrouter_cost(response) is None + + def test_returns_none_when_raw_has_no_headers(self): + response = MagicMock() + response._response = MagicMock(spec=[]) # no headers attr + assert llm.extract_openrouter_cost(response) is None + + def test_returns_zero_for_zero_cost(self): + """Zero-cost is a valid value (free tier) and must not become None.""" + response = self._mk_response({"x-total-cost": "0"}) + assert llm.extract_openrouter_cost(response) == 0.0 + + def test_returns_none_for_inf(self): + response = self._mk_response({"x-total-cost": "inf"}) + assert llm.extract_openrouter_cost(response) is None + + def test_returns_none_for_negative_inf(self): + response = self._mk_response({"x-total-cost": "-inf"}) + assert llm.extract_openrouter_cost(response) is None + + def test_returns_none_for_nan(self): + response = self._mk_response({"x-total-cost": "nan"}) + assert llm.extract_openrouter_cost(response) is None + + def test_returns_none_for_negative_cost(self): + response = self._mk_response({"x-total-cost": "-0.005"}) + assert llm.extract_openrouter_cost(response) is None + + +class TestAnthropicCacheControl: + """Verify that llm_call attaches cache_control to the system prompt block + and to the last tool definition when calling the Anthropic API.""" + + @pytest.fixture(autouse=True) + def disable_openrouter_routing(self): + """Ensure tests exercise the direct-Anthropic path by suppressing the + OpenRouter API key. Without this, a local .env with OPEN_ROUTER_API_KEY + set would silently reroute all Anthropic calls through OpenRouter, + bypassing the cache_control code under test.""" + with patch("backend.blocks.llm.settings") as mock_settings: + mock_settings.secrets.open_router_api_key = "" + yield mock_settings + + def _make_anthropic_credentials(self) -> llm.APIKeyCredentials: + from pydantic import SecretStr + + return llm.APIKeyCredentials( + id="test-anthropic-id", + provider="anthropic", + api_key=SecretStr("mock-anthropic-key"), + title="Mock Anthropic key", + expires_at=None, + ) + + @pytest.mark.asyncio + async def test_system_prompt_sent_as_block_with_cache_control(self): + """The system prompt is wrapped in a structured block with cache_control ephemeral.""" + mock_resp = MagicMock() + mock_resp.content = [MagicMock(type="text", text="hello")] + mock_resp.usage = MagicMock(input_tokens=5, output_tokens=3) + + captured_kwargs: dict = {} + + async def fake_create(**kwargs): + captured_kwargs.update(kwargs) + return mock_resp + + mock_client = MagicMock() + mock_client.messages.create = fake_create + + credentials = self._make_anthropic_credentials() + + with patch("anthropic.AsyncAnthropic", return_value=mock_client): + await llm.llm_call( + credentials=credentials, + llm_model=llm.LlmModel.CLAUDE_4_6_SONNET, + prompt=[ + {"role": "system", "content": "You are an assistant."}, + {"role": "user", "content": "Hello"}, + ], + max_tokens=100, + ) + + system_arg = captured_kwargs.get("system") + assert isinstance(system_arg, list), "system should be a list of blocks" + assert len(system_arg) == 1 + block = system_arg[0] + assert block["type"] == "text" + assert block["text"] == "You are an assistant." + assert block.get("cache_control") == {"type": "ephemeral"} + + @pytest.mark.asyncio + async def test_last_tool_gets_cache_control(self): + """cache_control is placed on the last tool in the Anthropic tools list.""" + mock_resp = MagicMock() + mock_resp.content = [MagicMock(type="text", text="ok")] + mock_resp.usage = MagicMock(input_tokens=10, output_tokens=5) + + captured_kwargs: dict = {} + + async def fake_create(**kwargs): + captured_kwargs.update(kwargs) + return mock_resp + + mock_client = MagicMock() + mock_client.messages.create = fake_create + + credentials = self._make_anthropic_credentials() + tools = [ + { + "type": "function", + "function": { + "name": "tool_a", + "description": "First tool", + "parameters": {"type": "object", "properties": {}, "required": []}, + }, + }, + { + "type": "function", + "function": { + "name": "tool_b", + "description": "Second tool", + "parameters": {"type": "object", "properties": {}, "required": []}, + }, + }, + ] + + with patch("anthropic.AsyncAnthropic", return_value=mock_client): + await llm.llm_call( + credentials=credentials, + llm_model=llm.LlmModel.CLAUDE_4_6_SONNET, + prompt=[ + {"role": "system", "content": "System."}, + {"role": "user", "content": "Do something"}, + ], + max_tokens=100, + tools=tools, + ) + + an_tools = captured_kwargs.get("tools") + assert isinstance(an_tools, list) + assert len(an_tools) == 2 + assert ( + an_tools[0].get("cache_control") is None + ), "Only last tool gets cache_control" + assert an_tools[-1].get("cache_control") == {"type": "ephemeral"} + + @pytest.mark.asyncio + async def test_no_tools_no_cache_control_on_tools(self): + """When there are no tools, the Anthropic call receives anthropic.NOT_GIVEN for tools.""" + mock_resp = MagicMock() + mock_resp.content = [MagicMock(type="text", text="ok")] + mock_resp.usage = MagicMock(input_tokens=5, output_tokens=2) + + captured_kwargs: dict = {} + + async def fake_create(**kwargs): + captured_kwargs.update(kwargs) + return mock_resp + + mock_client = MagicMock() + mock_client.messages.create = fake_create + + credentials = self._make_anthropic_credentials() + + with patch("anthropic.AsyncAnthropic", return_value=mock_client): + await llm.llm_call( + credentials=credentials, + llm_model=llm.LlmModel.CLAUDE_4_6_SONNET, + prompt=[ + {"role": "system", "content": "System."}, + {"role": "user", "content": "Hello"}, + ], + max_tokens=100, + tools=None, + ) + + import anthropic + + tools_arg = captured_kwargs.get("tools") + assert ( + tools_arg is anthropic.NOT_GIVEN + ), "Empty tools should pass anthropic.NOT_GIVEN sentinel" + + @pytest.mark.asyncio + async def test_empty_system_prompt_omits_system_key(self): + """When sysprompt is empty, the 'system' key must not be sent to Anthropic. + + Anthropic rejects empty text blocks; the guard in llm_call must ensure + the system argument is omitted entirely when no system messages are present. + """ + mock_resp = MagicMock() + mock_resp.content = [MagicMock(type="text", text="ok")] + mock_resp.usage = MagicMock(input_tokens=3, output_tokens=2) + + captured_kwargs: dict = {} + + async def fake_create(**kwargs): + captured_kwargs.update(kwargs) + return mock_resp + + mock_client = MagicMock() + mock_client.messages.create = fake_create + + credentials = self._make_anthropic_credentials() + + with patch("anthropic.AsyncAnthropic", return_value=mock_client): + await llm.llm_call( + credentials=credentials, + llm_model=llm.LlmModel.CLAUDE_4_6_SONNET, + prompt=[{"role": "user", "content": "Hi"}], + max_tokens=50, + ) + + assert ( + "system" not in captured_kwargs + ), "system must be omitted when sysprompt is empty to avoid Anthropic 400" + + @pytest.mark.asyncio + async def test_whitespace_only_system_prompt_omits_system_key(self): + """Whitespace-only system content is treated as empty and omitted. + + The guard in llm_call uses sysprompt.strip() so a prompt consisting of + only whitespace should NOT reach the Anthropic API (it would be rejected + as an empty text block). + """ + mock_resp = MagicMock() + mock_resp.content = [MagicMock(type="text", text="ok")] + mock_resp.usage = MagicMock(input_tokens=3, output_tokens=2) + + captured_kwargs: dict = {} + + async def fake_create(**kwargs): + captured_kwargs.update(kwargs) + return mock_resp + + mock_client = MagicMock() + mock_client.messages.create = fake_create + + credentials = self._make_anthropic_credentials() + + with patch("anthropic.AsyncAnthropic", return_value=mock_client): + await llm.llm_call( + credentials=credentials, + llm_model=llm.LlmModel.CLAUDE_4_6_SONNET, + prompt=[ + {"role": "system", "content": " \n\t "}, + {"role": "user", "content": "Hi"}, + ], + max_tokens=50, + ) + + assert ( + "system" not in captured_kwargs + ), "whitespace-only sysprompt must be omitted to avoid Anthropic 400" diff --git a/autogpt_platform/backend/backend/blocks/test/test_orchestrator.py b/autogpt_platform/backend/backend/blocks/test/test_orchestrator.py index 55f137428f..2eb27012dc 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_orchestrator.py +++ b/autogpt_platform/backend/backend/blocks/test/test_orchestrator.py @@ -922,6 +922,11 @@ async def test_orchestrator_agent_mode(): mock_execution_processor.on_node_execution = AsyncMock( return_value=mock_node_stats ) + # Mock charge_node_usage (called after successful tool execution). + # Returns (cost, remaining_balance). Must be AsyncMock because it is + # an async method and is directly awaited in _execute_single_tool_with_manager. + # Use a non-zero cost so the merge_stats branch is exercised. + mock_execution_processor.charge_node_usage = AsyncMock(return_value=(10, 990)) # Mock the get_execution_outputs_by_node_exec_id method mock_db_client.get_execution_outputs_by_node_exec_id.return_value = { @@ -967,6 +972,11 @@ async def test_orchestrator_agent_mode(): # Verify tool was executed via execution processor assert mock_execution_processor.on_node_execution.call_count == 1 + # Verify charge_node_usage was actually called for the successful + # tool execution — this guards against regressions where the + # post-execution tool charging is accidentally removed. + assert mock_execution_processor.charge_node_usage.call_count == 1 + @pytest.mark.asyncio async def test_orchestrator_traditional_mode_default(): diff --git a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py index ac4fa0710b..f2242ea527 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py +++ b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py @@ -306,6 +306,9 @@ async def test_output_yielding_with_dynamic_fields(): mock_response.raw_response = {"role": "assistant", "content": "test"} mock_response.prompt_tokens = 100 mock_response.completion_tokens = 50 + mock_response.cache_read_tokens = 0 + mock_response.cache_creation_tokens = 0 + mock_response.provider_cost = None # Mock the LLM call with patch( @@ -638,6 +641,14 @@ async def test_validation_errors_dont_pollute_conversation(): mock_execution_processor.on_node_execution.return_value = ( mock_node_stats ) + # Mock charge_node_usage (called after successful tool execution). + # Must be AsyncMock because it is async and is awaited in + # _execute_single_tool_with_manager — a plain MagicMock would + # return a non-awaitable tuple and TypeError out, then be + # silently swallowed by the orchestrator's catch-all. + mock_execution_processor.charge_node_usage = AsyncMock( + return_value=(0, 0) + ) async for output_name, output_value in block.run( input_data, diff --git a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_per_iteration_cost.py b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_per_iteration_cost.py new file mode 100644 index 0000000000..441bc08a42 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_per_iteration_cost.py @@ -0,0 +1,1020 @@ +"""Tests for OrchestratorBlock per-iteration cost charging. + +The OrchestratorBlock in agent mode makes multiple LLM calls in a single +node execution. The executor uses ``Block.extra_runtime_cost`` to detect +this and charge ``base_cost * (llm_call_count - 1)`` extra credits after +the block completes. +""" + +import threading +from collections import defaultdict +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.blocks._base import Block +from backend.blocks.orchestrator import ExecutionParams, OrchestratorBlock +from backend.data.execution import ExecutionContext, ExecutionStatus +from backend.data.model import NodeExecutionStats +from backend.executor import billing, manager +from backend.util.exceptions import InsufficientBalanceError + +# ── extra_runtime_cost hook ──────────────────────────────────────── + + +class _NoOpBlock(Block): + """Minimal concrete Block subclass that does not override extra_runtime_cost.""" + + def __init__(self): + super().__init__( + id="00000000-0000-0000-0000-000000000001", description="No-op test block" + ) + + def run(self, input_data, **kwargs): # type: ignore[override] + yield "out", {} + + +class TestExtraRuntimeCost: + """OrchestratorBlock opts into per-LLM-call billing via extra_runtime_cost.""" + + def test_orchestrator_returns_nonzero_for_multiple_calls(self): + block = OrchestratorBlock() + stats = NodeExecutionStats(llm_call_count=3) + assert block.extra_runtime_cost(stats) == 2 + + def test_orchestrator_returns_zero_for_single_call(self): + block = OrchestratorBlock() + stats = NodeExecutionStats(llm_call_count=1) + assert block.extra_runtime_cost(stats) == 0 + + def test_orchestrator_returns_zero_for_zero_calls(self): + block = OrchestratorBlock() + stats = NodeExecutionStats(llm_call_count=0) + assert block.extra_runtime_cost(stats) == 0 + + def test_default_block_returns_zero(self): + """A block that does not override extra_runtime_cost returns 0.""" + block = _NoOpBlock() + stats = NodeExecutionStats(llm_call_count=10) + assert block.extra_runtime_cost(stats) == 0 + + +# ── charge_extra_runtime_cost math ─────────────────────────────────── + + +@pytest.fixture() +def fake_node_exec(): + node_exec = MagicMock() + node_exec.user_id = "u" + node_exec.graph_exec_id = "g" + node_exec.graph_id = "g" + node_exec.node_exec_id = "ne" + node_exec.node_id = "n" + node_exec.block_id = "b" + node_exec.inputs = {} + return node_exec + + +@pytest.fixture() +def patched_processor(monkeypatch): + """ExecutionProcessor with stubbed db client / block lookup helpers. + + Returns the processor and a list of credit amounts spent so tests can + assert on what was charged. + + Note: ``ExecutionProcessor.__new__()`` bypasses ``__init__`` — if + ``__init__`` gains required state in the future this fixture will need + updating. + """ + spent: list[int] = [] + + class FakeDb: + def spend_credits(self, *, user_id, cost, metadata): + spent.append(cost) + return 1000 # remaining balance + + fake_block = MagicMock() + fake_block.name = "FakeBlock" + + monkeypatch.setattr(billing, "get_db_client", lambda: FakeDb()) + monkeypatch.setattr(billing, "get_block", lambda block_id: fake_block) + monkeypatch.setattr( + billing, + "block_usage_cost", + lambda block, input_data, **_kw: (10, {"model": "claude-sonnet-4-6"}), + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + return proc, spent + + +class TestChargeExtraRuntimeCost: + @pytest.mark.asyncio + async def test_zero_extra_iterations_charges_nothing( + self, patched_processor, fake_node_exec + ): + proc, spent = patched_processor + cost, balance = await proc.charge_extra_runtime_cost( + fake_node_exec, extra_count=0 + ) + assert cost == 0 + assert balance == 0 + assert spent == [] + + @pytest.mark.asyncio + async def test_extra_iterations_multiplies_base_cost( + self, patched_processor, fake_node_exec + ): + proc, spent = patched_processor + cost, balance = await proc.charge_extra_runtime_cost( + fake_node_exec, extra_count=4 + ) + assert cost == 40 # 4 × 10 + assert balance == 1000 + assert spent == [40] + + @pytest.mark.asyncio + async def test_negative_extra_iterations_charges_nothing( + self, patched_processor, fake_node_exec + ): + proc, spent = patched_processor + cost, balance = await proc.charge_extra_runtime_cost( + fake_node_exec, extra_count=-1 + ) + assert cost == 0 + assert balance == 0 + assert spent == [] + + @pytest.mark.asyncio + async def test_capped_at_max(self, monkeypatch, fake_node_exec): + """Runaway llm_call_count is capped at _MAX_EXTRA_RUNTIME_COST.""" + + spent: list[int] = [] + + class FakeDb: + def spend_credits(self, *, user_id, cost, metadata): + spent.append(cost) + return 1000 + + fake_block = MagicMock() + fake_block.name = "FakeBlock" + + monkeypatch.setattr(billing, "get_db_client", lambda: FakeDb()) + monkeypatch.setattr(billing, "get_block", lambda block_id: fake_block) + monkeypatch.setattr( + billing, + "block_usage_cost", + lambda block, input_data, **_kw: (10, {}), + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cap = billing._MAX_EXTRA_RUNTIME_COST + cost, _ = await proc.charge_extra_runtime_cost( + fake_node_exec, extra_count=cap * 100 + ) + # Charged at most cap × 10 + assert cost == cap * 10 + assert spent == [cap * 10] + + @pytest.mark.asyncio + async def test_zero_base_cost_skips_charge(self, monkeypatch, fake_node_exec): + + spent: list[int] = [] + + class FakeDb: + def spend_credits(self, *, user_id, cost, metadata): + spent.append(cost) + return 0 + + fake_block = MagicMock() + fake_block.name = "FakeBlock" + + monkeypatch.setattr(billing, "get_db_client", lambda: FakeDb()) + monkeypatch.setattr(billing, "get_block", lambda block_id: fake_block) + monkeypatch.setattr( + billing, "block_usage_cost", lambda block, input_data, **_kw: (0, {}) + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cost, balance = await proc.charge_extra_runtime_cost( + fake_node_exec, extra_count=4 + ) + assert cost == 0 + assert balance == 0 + assert spent == [] + + @pytest.mark.asyncio + async def test_block_not_found_skips_charge(self, monkeypatch, fake_node_exec): + + spent: list[int] = [] + + class FakeDb: + def spend_credits(self, *, user_id, cost, metadata): + spent.append(cost) + return 0 + + monkeypatch.setattr(billing, "get_db_client", lambda: FakeDb()) + monkeypatch.setattr(billing, "get_block", lambda block_id: None) + monkeypatch.setattr( + billing, "block_usage_cost", lambda block, input_data, **_kw: (10, {}) + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cost, balance = await proc.charge_extra_runtime_cost( + fake_node_exec, extra_count=3 + ) + assert cost == 0 + assert balance == 0 + assert spent == [] + + @pytest.mark.asyncio + async def test_propagates_insufficient_balance_error( + self, monkeypatch, fake_node_exec + ): + """Out-of-credits errors must propagate, not be silently swallowed.""" + + class FakeDb: + def spend_credits(self, *, user_id, cost, metadata): + raise InsufficientBalanceError( + user_id=user_id, + message="Insufficient balance", + balance=0, + amount=cost, + ) + + fake_block = MagicMock() + fake_block.name = "FakeBlock" + + monkeypatch.setattr(billing, "get_db_client", lambda: FakeDb()) + monkeypatch.setattr(billing, "get_block", lambda block_id: fake_block) + monkeypatch.setattr( + billing, "block_usage_cost", lambda block, input_data, **_kw: (10, {}) + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + with pytest.raises(InsufficientBalanceError): + await proc.charge_extra_runtime_cost(fake_node_exec, extra_count=4) + + +# ── charge_node_usage ────────────────────────────────────────────── + + +class TestChargeNodeUsage: + """charge_node_usage delegates to billing.charge_usage with execution_count=0.""" + + @pytest.mark.asyncio + async def test_delegates_with_zero_execution_count( + self, monkeypatch, fake_node_exec + ): + """Nested tool charges should NOT inflate the per-execution counter.""" + + captured: dict = {} + + def fake_charge_usage(node_exec, execution_count): + captured["execution_count"] = execution_count + captured["node_exec"] = node_exec + return (5, 100) + + def fake_handle_low_balance( + db_client, user_id, current_balance, transaction_cost + ): + pass + + monkeypatch.setattr(billing, "charge_usage", fake_charge_usage) + monkeypatch.setattr(billing, "handle_low_balance", fake_handle_low_balance) + monkeypatch.setattr(billing, "get_db_client", lambda: MagicMock()) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cost, balance = await proc.charge_node_usage(fake_node_exec) + assert cost == 5 + assert balance == 100 + assert captured["execution_count"] == 0 + + @pytest.mark.asyncio + async def test_calls_handle_low_balance_when_cost_nonzero( + self, monkeypatch, fake_node_exec + ): + """charge_node_usage should call handle_low_balance when total_cost > 0.""" + + low_balance_calls: list[dict] = [] + + def fake_charge_usage(node_exec, execution_count): + return (10, 50) + + def fake_handle_low_balance( + db_client, user_id, current_balance, transaction_cost + ): + low_balance_calls.append( + { + "user_id": user_id, + "current_balance": current_balance, + "transaction_cost": transaction_cost, + } + ) + + monkeypatch.setattr(billing, "charge_usage", fake_charge_usage) + monkeypatch.setattr(billing, "handle_low_balance", fake_handle_low_balance) + monkeypatch.setattr(billing, "get_db_client", lambda: MagicMock()) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cost, balance = await proc.charge_node_usage(fake_node_exec) + assert cost == 10 + assert balance == 50 + assert len(low_balance_calls) == 1 + assert low_balance_calls[0]["user_id"] == "u" + assert low_balance_calls[0]["current_balance"] == 50 + assert low_balance_calls[0]["transaction_cost"] == 10 + + @pytest.mark.asyncio + async def test_skips_handle_low_balance_when_cost_zero( + self, monkeypatch, fake_node_exec + ): + """charge_node_usage should NOT call handle_low_balance when cost is 0.""" + + low_balance_calls: list = [] + + def fake_charge_usage(node_exec, execution_count): + return (0, 200) + + def fake_handle_low_balance( + db_client, user_id, current_balance, transaction_cost + ): + low_balance_calls.append(True) + + monkeypatch.setattr(billing, "charge_usage", fake_charge_usage) + monkeypatch.setattr(billing, "handle_low_balance", fake_handle_low_balance) + monkeypatch.setattr(billing, "get_db_client", lambda: MagicMock()) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cost, balance = await proc.charge_node_usage(fake_node_exec) + assert cost == 0 + assert low_balance_calls == [] + + +# ── on_node_execution charging gate ──────────────────────────────── + + +class _FakeNode: + """Minimal stand-in for a ``Node`` object with a block attribute.""" + + def __init__(self, extra_charges: int = 0, block_name: str = "FakeBlock"): + self.block = MagicMock() + self.block.name = block_name + self.block.extra_runtime_cost = MagicMock(return_value=extra_charges) + + +class _FakeExecContext: + def __init__(self, dry_run: bool = False): + self.dry_run = dry_run + + +def _make_node_exec(dry_run: bool = False) -> MagicMock: + """Build a NodeExecutionEntry-like mock for on_node_execution tests.""" + ne = MagicMock() + ne.user_id = "u" + ne.graph_id = "g" + ne.graph_exec_id = "ge" + ne.node_id = "n" + ne.node_exec_id = "ne" + ne.block_id = "b" + ne.inputs = {} + ne.execution_context = _FakeExecContext(dry_run=dry_run) + return ne + + +@pytest.fixture() +def gated_processor(monkeypatch): + """ExecutionProcessor with on_node_execution's downstream calls stubbed. + + Lets tests flip the gate conditions (status, extra_runtime_cost result, + llm_call_count, dry_run) and observe whether charge_extra_runtime_cost + was called. + """ + + calls: dict[str, list] = { + "charge_extra_runtime_cost": [], + "handle_low_balance": [], + "handle_insufficient_funds_notif": [], + } + + # Stub node lookup + DB client so the wrapper doesn't touch real infra. + fake_db = MagicMock() + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=2)) + monkeypatch.setattr(manager, "get_db_async_client", lambda: fake_db) + monkeypatch.setattr(billing, "get_db_client", lambda: fake_db) + # get_block is called by LogMetadata construction in on_node_execution. + monkeypatch.setattr( + manager, + "get_block", + lambda block_id: MagicMock(name="FakeBlock"), + ) + # Persistence + cost logging are not under test here. + monkeypatch.setattr( + manager, + "async_update_node_execution_status", + AsyncMock(return_value=None), + ) + monkeypatch.setattr( + manager, + "async_update_graph_execution_state", + AsyncMock(return_value=None), + ) + monkeypatch.setattr( + manager, + "log_system_credential_cost", + AsyncMock(return_value=None), + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + + # Control the status returned by the inner execution call. + inner_result = {"status": ExecutionStatus.COMPLETED, "llm_call_count": 3} + + async def fake_inner( + self, + *, + node, + node_exec, + node_exec_progress, + stats, + db_client, + log_metadata, + nodes_input_masks=None, + nodes_to_skip=None, + ): + stats.llm_call_count = inner_result["llm_call_count"] + return MagicMock(wall_time=0.1, cpu_time=0.1), inner_result["status"] + + monkeypatch.setattr( + manager.ExecutionProcessor, + "_on_node_execution", + fake_inner, + ) + + async def fake_charge_extra(node_exec, extra_count): + calls["charge_extra_runtime_cost"].append(extra_count) + return (extra_count * 10, 500) + + monkeypatch.setattr(billing, "charge_extra_runtime_cost", fake_charge_extra) + + def fake_low_balance(db_client, user_id, current_balance, transaction_cost): + calls["handle_low_balance"].append( + { + "user_id": user_id, + "current_balance": current_balance, + "transaction_cost": transaction_cost, + } + ) + + monkeypatch.setattr(billing, "handle_low_balance", fake_low_balance) + + def fake_notif(db_client, user_id, graph_id, e): + calls["handle_insufficient_funds_notif"].append( + {"user_id": user_id, "graph_id": graph_id, "error": e} + ) + + monkeypatch.setattr(billing, "handle_insufficient_funds_notif", fake_notif) + + return proc, calls, inner_result, fake_db, NodeExecutionStats + + +@pytest.mark.asyncio +async def test_on_node_execution_charges_extra_iterations_when_gate_passes( + gated_processor, +): + """COMPLETED + extra_runtime_cost > 0 + not dry_run → charged.""" + + proc, calls, inner, fake_db, _ = gated_processor + inner["status"] = ExecutionStatus.COMPLETED + inner["llm_call_count"] = 3 # → extra_charges = 2 + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=2)) + + stats_pair = ( + MagicMock( + node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 + ), + threading.Lock(), + ) + await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=False), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + assert calls["charge_extra_runtime_cost"] == [2] + # handle_low_balance must be called with the remaining balance returned by + # charge_extra_runtime_cost (500) so users are alerted when balance drops low. + assert len(calls["handle_low_balance"]) == 1 + + +@pytest.mark.asyncio +async def test_on_node_execution_skips_when_status_not_completed(gated_processor): + + proc, calls, inner, fake_db, _ = gated_processor + inner["status"] = ExecutionStatus.FAILED + inner["llm_call_count"] = 5 + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=4)) + + stats_pair = ( + MagicMock( + node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 + ), + threading.Lock(), + ) + await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=False), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + assert calls["charge_extra_runtime_cost"] == [] + + +@pytest.mark.asyncio +async def test_on_node_execution_skips_when_extra_charges_zero(gated_processor): + + proc, calls, inner, fake_db, _ = gated_processor + inner["status"] = ExecutionStatus.COMPLETED + inner["llm_call_count"] = 5 + # Block returns 0 extra charges (base class default) + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=0)) + + stats_pair = ( + MagicMock( + node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 + ), + threading.Lock(), + ) + await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=False), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + assert calls["charge_extra_runtime_cost"] == [] + + +@pytest.mark.asyncio +async def test_on_node_execution_skips_when_dry_run(gated_processor): + + proc, calls, inner, fake_db, _ = gated_processor + inner["status"] = ExecutionStatus.COMPLETED + inner["llm_call_count"] = 5 + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=4)) + + stats_pair = ( + MagicMock( + node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 + ), + threading.Lock(), + ) + await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=True), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + assert calls["charge_extra_runtime_cost"] == [] + + +@pytest.mark.asyncio +async def test_on_node_execution_insufficient_balance_records_error_and_notifies( + monkeypatch, + gated_processor, +): + """When extra-iteration charging fails with InsufficientBalanceError: + + - the run still reports COMPLETED (the work is already done) + - execution_stats.error is NOT set (would flip node_error_count and + leak balance amounts into persisted node_stats — see manager.py + comment in the IBE handler) + - _handle_insufficient_funds_notif is called so the user is notified + - the structured ERROR log is the alerting hook + """ + + proc, calls, inner, fake_db, _ = gated_processor + inner["status"] = ExecutionStatus.COMPLETED + inner["llm_call_count"] = 4 + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=3)) + + async def raise_ibe(node_exec, extra_count): + raise InsufficientBalanceError( + user_id=node_exec.user_id, + message="Insufficient balance", + balance=0, + amount=extra_count * 10, + ) + + monkeypatch.setattr(billing, "charge_extra_runtime_cost", raise_ibe) + + stats_pair = ( + MagicMock( + node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 + ), + threading.Lock(), + ) + result_stats = await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=False), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + # error stays None — node ran to completion, only the post-hoc + # charge failed. Setting .error would (a) flip node_error_count++ + # creating an "errored COMPLETED node" inconsistency, and (b) leak + # balance amounts into persisted node_stats. + assert result_stats.error is None + # User notification fired. + assert len(calls["handle_insufficient_funds_notif"]) == 1 + assert calls["handle_insufficient_funds_notif"][0]["user_id"] == "u" + + +# ── Orchestrator _execute_single_tool_with_manager charging gates ── + + +async def _run_tool_exec_with_stats( + *, + dry_run: bool, + tool_stats_error, + charge_node_usage_mock=None, +): + """Invoke _execute_single_tool_with_manager against fully mocked deps + and return (charge_call_count, merge_stats_calls). + + Used to prove the dry_run and error guards around charge_node_usage + behave as documented, and that InsufficientBalanceError propagates. + """ + block = OrchestratorBlock() + + # Mocked async DB client used inside orchestrator. + mock_db_client = AsyncMock() + mock_target_node = MagicMock() + mock_target_node.block_id = "test-block-id" + mock_target_node.input_default = {} + mock_db_client.get_node.return_value = mock_target_node + mock_node_exec_result = MagicMock() + mock_node_exec_result.node_exec_id = "test-tool-exec-id" + mock_db_client.upsert_execution_input.return_value = ( + mock_node_exec_result, + {"query": "t"}, + ) + mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {"result": "ok"} + + # ExecutionProcessor mock: on_node_execution returns supplied error. + mock_processor = AsyncMock() + mock_processor.running_node_execution = defaultdict(MagicMock) + mock_processor.execution_stats = MagicMock() + mock_processor.execution_stats_lock = threading.Lock() + mock_node_stats = MagicMock() + mock_node_stats.error = tool_stats_error + mock_processor.on_node_execution = AsyncMock(return_value=mock_node_stats) + mock_processor.charge_node_usage = charge_node_usage_mock or AsyncMock( + return_value=(10, 990) + ) + + # Build a tool_info shaped like _build_tool_info_from_args output. + tool_call = MagicMock() + tool_call.id = "call-1" + tool_call.name = "search_keywords" + tool_call.arguments = '{"query":"t"}' + tool_def = { + "type": "function", + "function": { + "name": "search_keywords", + "_sink_node_id": "test-sink-node-id", + "_field_mapping": {}, + "parameters": { + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + } + tool_info = OrchestratorBlock._build_tool_info_from_args( + tool_call_id="call-1", + tool_name="search_keywords", + tool_args={"query": "t"}, + tool_def=tool_def, + ) + + exec_params = ExecutionParams( + user_id="u", + graph_id="g", + node_id="n", + graph_version=1, + graph_exec_id="ge", + node_exec_id="ne", + execution_context=ExecutionContext( + human_in_the_loop_safe_mode=False, dry_run=dry_run + ), + ) + + with patch( + "backend.blocks.orchestrator.get_database_manager_async_client", + return_value=mock_db_client, + ): + try: + await block._execute_single_tool_with_manager( + tool_info, exec_params, mock_processor, responses_api=False + ) + raised = None + except Exception as e: + raised = e + + return mock_processor.charge_node_usage, raised + + +@pytest.mark.asyncio +async def test_tool_execution_skips_charging_on_dry_run(): + """dry_run=True → charge_node_usage is NOT called.""" + charge_mock, raised = await _run_tool_exec_with_stats( + dry_run=True, tool_stats_error=None + ) + assert raised is None + assert charge_mock.call_count == 0 + + +@pytest.mark.asyncio +async def test_tool_execution_skips_charging_on_failed_tool(): + """tool_node_stats.error is an Exception → charge_node_usage NOT called.""" + charge_mock, raised = await _run_tool_exec_with_stats( + dry_run=False, tool_stats_error=RuntimeError("tool blew up") + ) + assert raised is None + assert charge_mock.call_count == 0 + + +@pytest.mark.asyncio +async def test_tool_execution_skips_charging_on_cancelled_tool(): + """Cancellation (BaseException subclass) → charge_node_usage NOT called. + + Guards the fix for sentry's BaseException concern: the old + `isinstance(error, Exception)` check would have treated CancelledError + as "no error" and billed the user for a terminated run. + """ + import asyncio as _asyncio + + charge_mock, raised = await _run_tool_exec_with_stats( + dry_run=False, tool_stats_error=_asyncio.CancelledError() + ) + assert raised is None + assert charge_mock.call_count == 0 + + +@pytest.mark.asyncio +async def test_tool_execution_insufficient_balance_propagates(): + """InsufficientBalanceError from charge_node_usage must propagate out. + + If this leaked into a ToolCallResult the LLM loop would keep running + with 'tool failed' errors and the user would get unpaid work. + """ + raising_charge = AsyncMock( + side_effect=InsufficientBalanceError( + user_id="u", message="nope", balance=0, amount=10 + ) + ) + _, raised = await _run_tool_exec_with_stats( + dry_run=False, + tool_stats_error=None, + charge_node_usage_mock=raising_charge, + ) + assert isinstance(raised, InsufficientBalanceError) + + +@pytest.mark.asyncio +async def test_tool_execution_on_node_execution_returns_none_sets_is_error(): + """on_node_execution returning None (swallowed by @async_error_logged) must + result in a tool response with _is_error=True so the LLM loop knows the + tool failed and does not treat a silent error as a successful execution. + """ + block = OrchestratorBlock() + + mock_db_client = AsyncMock() + mock_target_node = MagicMock() + mock_target_node.block_id = "test-block-id" + mock_target_node.input_default = {} + mock_db_client.get_node.return_value = mock_target_node + mock_node_exec_result = MagicMock() + mock_node_exec_result.node_exec_id = "test-tool-exec-id" + mock_db_client.upsert_execution_input.return_value = ( + mock_node_exec_result, + {"query": "t"}, + ) + + mock_processor = AsyncMock() + mock_processor.running_node_execution = defaultdict(MagicMock) + mock_processor.execution_stats = MagicMock() + mock_processor.execution_stats_lock = threading.Lock() + # on_node_execution returns None — simulates @async_error_logged(swallow=True) + # swallowing an internal error + mock_processor.on_node_execution = AsyncMock(return_value=None) + + tool_call = MagicMock() + tool_call.id = "call-none" + tool_call.name = "search_keywords" + tool_call.arguments = '{"query":"t"}' + tool_def = { + "type": "function", + "function": { + "name": "search_keywords", + "_sink_node_id": "test-sink-node-id", + "_field_mapping": {}, + "parameters": { + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + } + tool_info = OrchestratorBlock._build_tool_info_from_args( + tool_call_id="call-none", + tool_name="search_keywords", + tool_args={"query": "t"}, + tool_def=tool_def, + ) + + exec_params = ExecutionParams( + user_id="u", + graph_id="g", + node_id="n", + graph_version=1, + graph_exec_id="ge", + node_exec_id="ne", + execution_context=ExecutionContext( + human_in_the_loop_safe_mode=False, dry_run=False + ), + ) + + with patch( + "backend.blocks.orchestrator.get_database_manager_async_client", + return_value=mock_db_client, + ): + resp = await block._execute_single_tool_with_manager( + tool_info, exec_params, mock_processor, responses_api=False + ) + + assert resp.get("_is_error") is True + # charge_node_usage must NOT be called for a failed tool execution + mock_processor.charge_node_usage.assert_not_called() + + +# ── on_node_execution FAILED + InsufficientBalanceError notification ── + + +@pytest.mark.asyncio +async def test_on_node_execution_failed_ibe_sends_notification( + monkeypatch, + gated_processor, +): + """When status == FAILED and execution_stats.error is InsufficientBalanceError, + _handle_insufficient_funds_notif must be called. + + This path fires when a nested tool charge inside the orchestrator raises + InsufficientBalanceError, which propagates out of the block's run() generator + and is caught by _on_node_execution's broad except, setting status=FAILED and + execution_stats.error=IBE. on_node_execution's post-execution block then + sends the user notification so they understand why the run stopped. + """ + + proc, calls, inner, fake_db, NodeExecutionStats = gated_processor + ibe = InsufficientBalanceError( + user_id="u", + message="Insufficient balance", + balance=0, + amount=30, + ) + + # Simulate _on_node_execution returning FAILED with IBE in stats.error. + async def fake_inner_failed( + self, + *, + node, + node_exec, + node_exec_progress, + stats, + db_client, + log_metadata, + nodes_input_masks=None, + nodes_to_skip=None, + ): + stats.error = ibe + return MagicMock(wall_time=0.1, cpu_time=0.1), ExecutionStatus.FAILED + + monkeypatch.setattr( + manager.ExecutionProcessor, + "_on_node_execution", + fake_inner_failed, + ) + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=0)) + + stats_pair = ( + MagicMock( + node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 + ), + threading.Lock(), + ) + await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=False), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + # The notification must have fired so the user knows why their run stopped. + assert len(calls["handle_insufficient_funds_notif"]) == 1 + assert calls["handle_insufficient_funds_notif"][0]["user_id"] == "u" + # charge_extra_runtime_cost must NOT be called — status is FAILED. + assert calls["charge_extra_runtime_cost"] == [] + + +# ── Billing leak: non-IBE exception during extra-iteration charging ── + + +@pytest.mark.asyncio +async def test_on_node_execution_non_ibe_billing_failure_keeps_completed( + monkeypatch, + gated_processor, +): + """When charge_extra_runtime_cost raises a non-IBE exception (e.g. DB outage): + + - execution_stats.error stays None (node ran to completion) + - status stays COMPLETED (work already done) + - the billing_leak error is logged but does not corrupt execution_stats + """ + proc, calls, inner, fake_db, _ = gated_processor + inner["status"] = ExecutionStatus.COMPLETED + inner["llm_call_count"] = 4 + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=3)) + + async def raise_conn_error(node_exec, extra_count): + raise ConnectionError("DB connection lost") + + monkeypatch.setattr(billing, "charge_extra_runtime_cost", raise_conn_error) + + stats_pair = ( + MagicMock( + node_count=0, + nodes_cputime=0, + nodes_walltime=0, + cost=0, + node_error_count=0, + ), + threading.Lock(), + ) + result_stats = await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=False), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + # error stays None — node completed, only billing failed. + assert result_stats.error is None + # No notification was sent (only IBE triggers notification). + assert len(calls["handle_insufficient_funds_notif"]) == 0 + + +# ── _charge_usage with execution_count=0 ── + + +class TestChargeUsageZeroExecutionCount: + """Verify _charge_usage(node_exec, 0) does not invoke execution_usage_cost.""" + + def test_execution_count_zero_skips_execution_tier(self, monkeypatch): + """_charge_usage with execution_count=0 must not call execution_usage_cost.""" + execution_tier_called = [] + + def fake_execution_usage_cost(count): + execution_tier_called.append(count) + return (100, count) + + spent: list[int] = [] + + class FakeDb: + def spend_credits(self, *, user_id, cost, metadata): + spent.append(cost) + return 500 + + fake_block = MagicMock() + fake_block.name = "FakeBlock" + + monkeypatch.setattr(billing, "get_db_client", lambda: FakeDb()) + monkeypatch.setattr(billing, "get_block", lambda block_id: fake_block) + monkeypatch.setattr( + billing, + "block_usage_cost", + lambda block, input_data, **_kw: (10, {}), + ) + monkeypatch.setattr(billing, "execution_usage_cost", fake_execution_usage_cost) + + ne = MagicMock() + ne.user_id = "u" + ne.graph_exec_id = "ge" + ne.graph_id = "g" + ne.node_exec_id = "ne" + ne.node_id = "n" + ne.block_id = "b" + ne.inputs = {} + + total_cost, remaining = billing.charge_usage(ne, 0) + assert total_cost == 10 # block cost only + assert remaining == 500 + assert spent == [10] + # execution_usage_cost must NOT have been called + assert execution_tier_called == [] diff --git a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_responses_api.py b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_responses_api.py index b14e24e39f..ac78b6d35b 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_responses_api.py +++ b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_responses_api.py @@ -211,6 +211,30 @@ class TestConvertRawResponseToDict: # A single dict is wrong — there are two distinct items pytest.fail("Expected a list of output items, got a single dict") + def test_responses_api_strips_status_from_function_call(self): + """Responses API function_call items have a 'status' field that OpenAI + rejects when sent back as input ('Unknown parameter: input[N].status'). + It must be stripped before the item is stored in conversation history.""" + resp = _MockResponse( + output=[_MockFunctionCall("my_tool", '{"x": 1}', call_id="call_xyz")] + ) + result = _convert_raw_response_to_dict(resp) + assert isinstance(result, list) + for item in result: + assert ( + "status" not in item + ), f"'status' must be stripped from Responses API items: {item}" + + def test_responses_api_strips_status_from_message(self): + """Responses API message items also carry 'status'; it must be stripped.""" + resp = _MockResponse(output=[_MockOutputMessage("Hello")]) + result = _convert_raw_response_to_dict(resp) + assert isinstance(result, list) + for item in result: + assert ( + "status" not in item + ), f"'status' must be stripped from Responses API items: {item}" + # ─────────────────────────────────────────────────────────────────────────── # _get_tool_requests (lines 61-86) @@ -932,6 +956,12 @@ async def test_agent_mode_conversation_valid_for_responses_api(): ep.execution_stats_lock = threading.Lock() ns = MagicMock(error=None) ep.on_node_execution = AsyncMock(return_value=ns) + # Mock charge_node_usage (called after successful tool execution). + # Must be AsyncMock because it is async and is awaited in + # _execute_single_tool_with_manager — a plain MagicMock would return a + # non-awaitable tuple and TypeError out, then be silently swallowed by + # the orchestrator's catch-all. + ep.charge_node_usage = AsyncMock(return_value=(0, 0)) with patch("backend.blocks.llm.llm_call", llm_mock), patch.object( block, "_create_tool_node_signatures", return_value=tool_sigs diff --git a/autogpt_platform/backend/backend/blocks/text_to_speech_block.py b/autogpt_platform/backend/backend/blocks/text_to_speech_block.py index a408c8772f..1860d10d24 100644 --- a/autogpt_platform/backend/backend/blocks/text_to_speech_block.py +++ b/autogpt_platform/backend/backend/blocks/text_to_speech_block.py @@ -13,6 +13,7 @@ from backend.data.model import ( APIKeyCredentials, CredentialsField, CredentialsMetaInput, + NodeExecutionStats, SchemaField, ) from backend.integrations.providers import ProviderName @@ -104,4 +105,10 @@ class UnrealTextToSpeechBlock(Block): input_data.text, input_data.voice_id, ) + self.merge_stats( + NodeExecutionStats( + provider_cost=float(len(input_data.text)), + provider_cost_type="characters", + ) + ) yield "mp3_url", api_response["OutputUri"] diff --git a/autogpt_platform/backend/backend/copilot/baseline/reasoning.py b/autogpt_platform/backend/backend/copilot/baseline/reasoning.py new file mode 100644 index 0000000000..15a77dde8a --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/baseline/reasoning.py @@ -0,0 +1,230 @@ +"""Extended-thinking wire support for the baseline (OpenRouter) path. + +Anthropic routes on OpenRouter expose extended thinking through +non-OpenAI extension fields that the OpenAI Python SDK doesn't model: + +* ``reasoning`` (legacy string) — enabled by ``include_reasoning: true``. +* ``reasoning_content`` — DeepSeek / some OpenRouter routes. +* ``reasoning_details`` — structured list shipped with the unified + ``reasoning`` request param. + +This module keeps the wire-level concerns in one place: + +* :class:`OpenRouterDeltaExtension` validates the extension dict pulled off + ``ChoiceDelta.model_extra`` into typed pydantic models — no ``getattr`` + + ``isinstance`` duck-typing at the call site. +* :class:`BaselineReasoningEmitter` owns the reasoning block lifecycle for + one streaming round and emits ``StreamReasoning*`` events so the caller + only has to plumb the events into its pending queue. +* :func:`reasoning_extra_body` builds the ``extra_body`` fragment for the + OpenAI client call. Returns ``None`` on non-Anthropic routes. +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Any + +from openai.types.chat.chat_completion_chunk import ChoiceDelta +from pydantic import BaseModel, ConfigDict, Field, ValidationError + +from backend.copilot.model import ChatMessage +from backend.copilot.response_model import ( + StreamBaseResponse, + StreamReasoningDelta, + StreamReasoningEnd, + StreamReasoningStart, +) + +logger = logging.getLogger(__name__) + + +_VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"}) + + +class ReasoningDetail(BaseModel): + """One entry in OpenRouter's ``reasoning_details`` list. + + OpenRouter ships ``type: "reasoning.text"`` / ``"reasoning.summary"`` / + ``"reasoning.encrypted"`` entries. Only the first two carry + user-visible text; encrypted entries are opaque and omitted from the + rendered collapse. Unknown future types are tolerated (``extra="ignore"``) + so an upstream addition doesn't crash the stream — but their ``text`` / + ``summary`` fields are NOT surfaced because they may carry provider + metadata rather than user-visible reasoning (see + :attr:`visible_text`). + """ + + model_config = ConfigDict(extra="ignore") + + type: str | None = None + text: str | None = None + summary: str | None = None + + @property + def visible_text(self) -> str: + """Return the human-readable text for this entry, or ``""``. + + Only entries with a recognised reasoning type (``reasoning.text`` / + ``reasoning.summary``) surface text; unknown or encrypted types + return an empty string even if they carry a ``text`` / + ``summary`` field, to guard against future provider metadata + being rendered as reasoning in the UI. Entries missing a + ``type`` are treated as text (pre-``reasoning_details`` OpenRouter + payloads omit the field). + """ + if self.type is not None and self.type not in _VISIBLE_REASONING_TYPES: + return "" + return self.text or self.summary or "" + + +class OpenRouterDeltaExtension(BaseModel): + """Non-OpenAI fields OpenRouter adds to streaming deltas. + + Instantiate via :meth:`from_delta` which pulls the extension dict off + ``ChoiceDelta.model_extra`` (where pydantic v2 stashes fields that + aren't part of the declared schema) and validates it through this + model. That keeps the parser honest — malformed entries surface as + validation errors rather than silent ``None``-coalesce bugs — and + avoids the ``getattr`` + ``isinstance`` duck-typing the earlier inline + extractor relied on. + """ + + model_config = ConfigDict(extra="ignore") + + reasoning: str | None = None + reasoning_content: str | None = None + reasoning_details: list[ReasoningDetail] = Field(default_factory=list) + + @classmethod + def from_delta(cls, delta: ChoiceDelta) -> "OpenRouterDeltaExtension": + """Build an extension view from ``delta.model_extra``. + + Malformed provider payloads (e.g. ``reasoning_details`` shipped as + a string rather than a list) surface as a ``ValidationError`` which + is logged and swallowed — returning an empty extension so the rest + of the stream (valid text / tool calls) keeps flowing. An optional + feature's corrupted wire data must never abort the whole stream. + """ + try: + return cls.model_validate(delta.model_extra or {}) + except ValidationError as exc: + logger.warning( + "[Baseline] Dropping malformed OpenRouter reasoning payload: %s", + exc, + ) + return cls() + + def visible_text(self) -> str: + """Concatenated reasoning text, pulled from whichever channel is set. + + Priority: the legacy ``reasoning`` string, then DeepSeek's + ``reasoning_content``, then the concatenation of text-bearing + entries in ``reasoning_details``. Only one channel is set per + provider in practice; the priority order just makes the fallback + deterministic if a provider ever emits multiple. + """ + if self.reasoning: + return self.reasoning + if self.reasoning_content: + return self.reasoning_content + return "".join(d.visible_text for d in self.reasoning_details) + + +def reasoning_extra_body(model: str, max_thinking_tokens: int) -> dict[str, Any] | None: + """Build the ``extra_body["reasoning"]`` fragment for the OpenAI client. + + Returns ``None`` for non-Anthropic routes (other OpenRouter providers + ignore the field but we skip it anyway to keep the payload minimal) + and for ``max_thinking_tokens <= 0`` (operator kill switch). + """ + # Imported lazily to avoid pulling service.py at module load — service.py + # imports this module, and the lazy import keeps the dependency one-way. + from backend.copilot.baseline.service import _is_anthropic_model + + if not _is_anthropic_model(model) or max_thinking_tokens <= 0: + return None + return {"reasoning": {"max_tokens": max_thinking_tokens}} + + +class BaselineReasoningEmitter: + """Owns the reasoning block lifecycle for one streaming round. + + Two concerns live here, both driven by the same state machine: + + 1. **Wire events.** The AI SDK v6 wire format pairs every + ``reasoning-start`` with a matching ``reasoning-end`` and treats + reasoning / text / tool-use as distinct UI parts that must not + interleave. + 2. **Session persistence.** ``ChatMessage(role="reasoning")`` rows in + ``session.messages`` are what + ``convertChatSessionToUiMessages.ts`` folds into the assistant + bubble as ``{type: "reasoning"}`` UI parts on reload and on + ``useHydrateOnStreamEnd`` swaps. Without them the live-streamed + reasoning parts get overwritten by the hydrated (reasoning-less) + message list the moment the stream ends. Mirrors the SDK path's + ``acc.reasoning_response`` pattern so both routes render the same + way on reload. + + Pass ``session_messages`` to enable persistence; omit for pure + wire-emission (tests, scratch callers). On first reasoning delta a + fresh ``ChatMessage(role="reasoning")`` is appended and mutated + in-place as further deltas arrive; :meth:`close` drops the reference + but leaves the appended row intact. + """ + + def __init__( + self, + session_messages: list[ChatMessage] | None = None, + ) -> None: + self._block_id: str = str(uuid.uuid4()) + self._open: bool = False + self._session_messages = session_messages + self._current_row: ChatMessage | None = None + + @property + def is_open(self) -> bool: + return self._open + + def on_delta(self, delta: ChoiceDelta) -> list[StreamBaseResponse]: + """Return events for the reasoning text carried by *delta*. + + Empty list when the chunk carries no reasoning payload, so this is + safe to call on every chunk without guarding at the call site. + Persistence (when a session message list is attached) happens in + lockstep with emission so the row's content stays equal to the + concatenated deltas at every delta boundary. + """ + ext = OpenRouterDeltaExtension.from_delta(delta) + text = ext.visible_text() + if not text: + return [] + events: list[StreamBaseResponse] = [] + if not self._open: + events.append(StreamReasoningStart(id=self._block_id)) + self._open = True + if self._session_messages is not None: + self._current_row = ChatMessage(role="reasoning", content="") + self._session_messages.append(self._current_row) + events.append(StreamReasoningDelta(id=self._block_id, delta=text)) + if self._current_row is not None: + self._current_row.content = (self._current_row.content or "") + text + return events + + def close(self) -> list[StreamBaseResponse]: + """Emit ``StreamReasoningEnd`` for the open block (if any) and rotate. + + Idempotent — returns ``[]`` when no block is open. The id rotation + guarantees the next reasoning block starts with a fresh id rather + than reusing one already closed on the wire. The persisted row is + not removed — it stays in ``session_messages`` as the durable + record of what was reasoned. + """ + if not self._open: + return [] + event = StreamReasoningEnd(id=self._block_id) + self._open = False + self._block_id = str(uuid.uuid4()) + self._current_row = None + return [event] diff --git a/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py b/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py new file mode 100644 index 0000000000..df64086d5f --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py @@ -0,0 +1,281 @@ +"""Tests for the baseline reasoning extension module. + +Covers the typed OpenRouter delta parser, the stateful emitter, and the +``extra_body`` builder. The emitter is tested against real +``ChoiceDelta`` pydantic instances so the ``model_extra`` plumbing the +parser relies on is exercised end-to-end. +""" + +from openai.types.chat.chat_completion_chunk import ChoiceDelta + +from backend.copilot.baseline.reasoning import ( + BaselineReasoningEmitter, + OpenRouterDeltaExtension, + ReasoningDetail, + reasoning_extra_body, +) +from backend.copilot.model import ChatMessage +from backend.copilot.response_model import ( + StreamReasoningDelta, + StreamReasoningEnd, + StreamReasoningStart, +) + + +def _delta(**extra) -> ChoiceDelta: + """Build a ChoiceDelta with the given extension fields on ``model_extra``.""" + return ChoiceDelta.model_validate({"role": "assistant", **extra}) + + +class TestReasoningDetail: + def test_visible_text_prefers_text(self): + d = ReasoningDetail(type="reasoning.text", text="hi", summary="ignored") + assert d.visible_text == "hi" + + def test_visible_text_falls_back_to_summary(self): + d = ReasoningDetail(type="reasoning.summary", summary="tldr") + assert d.visible_text == "tldr" + + def test_visible_text_empty_for_encrypted(self): + d = ReasoningDetail(type="reasoning.encrypted") + assert d.visible_text == "" + + def test_unknown_fields_are_ignored(self): + # OpenRouter may add new fields in future payloads — they shouldn't + # cause validation errors. + d = ReasoningDetail.model_validate( + {"type": "reasoning.future", "text": "x", "signature": "opaque"} + ) + assert d.text == "x" + + def test_visible_text_empty_for_unknown_type(self): + # Unknown types may carry provider metadata that must not render as + # user-visible reasoning — regardless of whether a text/summary is + # present. Only ``reasoning.text`` / ``reasoning.summary`` surface. + d = ReasoningDetail(type="reasoning.future", text="leaked metadata") + assert d.visible_text == "" + + def test_visible_text_surfaces_text_when_type_missing(self): + # Pre-``reasoning_details`` OpenRouter payloads omit ``type`` — treat + # them as text so we don't regress the legacy structured shape. + d = ReasoningDetail(text="plain") + assert d.visible_text == "plain" + + +class TestOpenRouterDeltaExtension: + def test_from_delta_reads_model_extra(self): + delta = _delta(reasoning="step one") + ext = OpenRouterDeltaExtension.from_delta(delta) + assert ext.reasoning == "step one" + + def test_visible_text_legacy_string(self): + ext = OpenRouterDeltaExtension(reasoning="plain text") + assert ext.visible_text() == "plain text" + + def test_visible_text_deepseek_alias(self): + ext = OpenRouterDeltaExtension(reasoning_content="alt channel") + assert ext.visible_text() == "alt channel" + + def test_visible_text_structured_details_concat(self): + ext = OpenRouterDeltaExtension( + reasoning_details=[ + ReasoningDetail(type="reasoning.text", text="hello "), + ReasoningDetail(type="reasoning.text", text="world"), + ] + ) + assert ext.visible_text() == "hello world" + + def test_visible_text_skips_encrypted(self): + ext = OpenRouterDeltaExtension( + reasoning_details=[ + ReasoningDetail(type="reasoning.encrypted"), + ReasoningDetail(type="reasoning.text", text="visible"), + ] + ) + assert ext.visible_text() == "visible" + + def test_visible_text_empty_when_all_channels_blank(self): + ext = OpenRouterDeltaExtension() + assert ext.visible_text() == "" + + def test_empty_delta_produces_empty_extension(self): + ext = OpenRouterDeltaExtension.from_delta(_delta()) + assert ext.reasoning is None + assert ext.reasoning_content is None + assert ext.reasoning_details == [] + + def test_malformed_reasoning_payload_logged_and_swallowed(self, caplog): + # A malformed payload (e.g. reasoning_details shipped as a string + # rather than a list) must not abort the stream — log it and + # return an empty extension so valid text/tool events keep flowing. + # A plain mock is used here because ``from_delta`` only reads + # ``delta.model_extra`` — avoids reaching into pydantic internals + # (``__pydantic_extra__``) that could be renamed across versions. + from unittest.mock import MagicMock + + delta = MagicMock(spec=ChoiceDelta) + delta.model_extra = {"reasoning_details": "not a list"} + with caplog.at_level("WARNING"): + ext = OpenRouterDeltaExtension.from_delta(delta) + assert ext.reasoning_details == [] + assert ext.visible_text() == "" + assert any("malformed" in r.message.lower() for r in caplog.records) + + def test_unknown_typed_entry_with_text_is_not_surfaced(self): + # Regression: the legacy extractor emitted any entry with a + # ``text`` or ``summary`` field. The typed parser now filters on + # the recognised types so future provider metadata can't leak + # into the reasoning collapse. + ext = OpenRouterDeltaExtension( + reasoning_details=[ + ReasoningDetail(type="reasoning.future", text="provider metadata"), + ReasoningDetail(type="reasoning.text", text="real"), + ] + ) + assert ext.visible_text() == "real" + + +class TestReasoningExtraBody: + def test_anthropic_route_returns_fragment(self): + assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 4096) == { + "reasoning": {"max_tokens": 4096} + } + + def test_direct_claude_model_id_still_matches(self): + assert reasoning_extra_body("claude-3-5-sonnet-20241022", 2048) == { + "reasoning": {"max_tokens": 2048} + } + + def test_non_anthropic_route_returns_none(self): + assert reasoning_extra_body("openai/gpt-4o", 4096) is None + assert reasoning_extra_body("google/gemini-2.5-pro", 4096) is None + + def test_zero_max_tokens_kill_switch(self): + # Operator kill switch: ``max_thinking_tokens <= 0`` disables the + # ``reasoning`` extra_body fragment even on an Anthropic route. + # Lets us silence reasoning without dropping the SDK path's budget. + assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 0) is None + assert reasoning_extra_body("anthropic/claude-sonnet-4-6", -1) is None + + +class TestBaselineReasoningEmitter: + def test_first_text_delta_emits_start_then_delta(self): + emitter = BaselineReasoningEmitter() + events = emitter.on_delta(_delta(reasoning="thinking")) + + assert len(events) == 2 + assert isinstance(events[0], StreamReasoningStart) + assert isinstance(events[1], StreamReasoningDelta) + assert events[0].id == events[1].id + assert events[1].delta == "thinking" + assert emitter.is_open is True + + def test_subsequent_deltas_reuse_block_id_without_new_start(self): + emitter = BaselineReasoningEmitter() + first = emitter.on_delta(_delta(reasoning="a")) + second = emitter.on_delta(_delta(reasoning="b")) + + assert any(isinstance(e, StreamReasoningStart) for e in first) + assert all(not isinstance(e, StreamReasoningStart) for e in second) + assert len(second) == 1 + assert isinstance(second[0], StreamReasoningDelta) + assert first[0].id == second[0].id + + def test_empty_delta_emits_nothing(self): + emitter = BaselineReasoningEmitter() + assert emitter.on_delta(_delta(content="hello")) == [] + assert emitter.is_open is False + + def test_close_emits_end_and_rotates_id(self): + emitter = BaselineReasoningEmitter() + # Capture the block id from the wire event rather than reaching + # into emitter internals — the id on the emitted Start/Delta is + # what the frontend actually receives. + start_events = emitter.on_delta(_delta(reasoning="x")) + first_id = start_events[0].id + + events = emitter.close() + assert len(events) == 1 + assert isinstance(events[0], StreamReasoningEnd) + assert events[0].id == first_id + assert emitter.is_open is False + # Next reasoning uses a fresh id. + new_events = emitter.on_delta(_delta(reasoning="y")) + assert isinstance(new_events[0], StreamReasoningStart) + assert new_events[0].id != first_id + + def test_close_is_idempotent(self): + emitter = BaselineReasoningEmitter() + assert emitter.close() == [] + emitter.on_delta(_delta(reasoning="x")) + assert len(emitter.close()) == 1 + assert emitter.close() == [] + + def test_structured_details_round_trip(self): + emitter = BaselineReasoningEmitter() + events = emitter.on_delta( + _delta( + reasoning_details=[ + {"type": "reasoning.text", "text": "plan: "}, + {"type": "reasoning.summary", "summary": "do the thing"}, + ] + ) + ) + deltas = [e for e in events if isinstance(e, StreamReasoningDelta)] + assert len(deltas) == 1 + assert deltas[0].delta == "plan: do the thing" + + +class TestReasoningPersistence: + """The persistence contract: without ``role="reasoning"`` rows in + session.messages, useHydrateOnStreamEnd overwrites the live-streamed + reasoning parts and the Reasoning collapse vanishes. Every delta + must be reflected in the persisted row the moment it's emitted.""" + + def test_session_row_appended_on_first_delta(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + assert session == [] + emitter.on_delta(_delta(reasoning="hi")) + assert len(session) == 1 + assert session[0].role == "reasoning" + assert session[0].content == "hi" + + def test_subsequent_deltas_mutate_same_row(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + emitter.on_delta(_delta(reasoning="part one ")) + emitter.on_delta(_delta(reasoning="part two")) + + assert len(session) == 1 + assert session[0].content == "part one part two" + + def test_close_keeps_row_in_session(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + emitter.on_delta(_delta(reasoning="thought")) + emitter.close() + + assert len(session) == 1 + assert session[0].content == "thought" + + def test_second_reasoning_block_appends_new_row(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + emitter.on_delta(_delta(reasoning="first")) + emitter.close() + emitter.on_delta(_delta(reasoning="second")) + + assert len(session) == 2 + assert [m.content for m in session] == ["first", "second"] + + def test_no_session_means_no_persistence(self): + """Emitter without attached session list emits wire events only.""" + emitter = BaselineReasoningEmitter() + events = emitter.on_delta(_delta(reasoning="pure wire")) + assert len(events) == 2 # start + delta, no crash + # Nothing else to assert — just proves None session is supported. diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index 379686b64d..474a6834b1 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -7,26 +7,55 @@ shared tool registry as the SDK path. """ import asyncio +import base64 import logging +import math +import os +import re +import shutil +import tempfile import uuid -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator, Mapping, Sequence from dataclasses import dataclass, field from functools import partial -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import orjson from langfuse import propagate_attributes +from openai.types import CompletionUsage from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam +from openai.types.completion_usage import PromptTokensDetails +from opentelemetry import trace as otel_trace -from backend.copilot.context import set_execution_context +from backend.copilot.baseline.reasoning import ( + BaselineReasoningEmitter, + reasoning_extra_body, +) +from backend.copilot.builder_context import ( + build_builder_context_turn_prefix, + build_builder_system_prompt_suffix, +) +from backend.copilot.config import CopilotLlmModel, CopilotMode +from backend.copilot.context import get_workspace_manager, set_execution_context +from backend.copilot.graphiti.config import is_enabled_for_user from backend.copilot.model import ( ChatMessage, ChatSession, get_chat_session, - update_session_title, + maybe_append_user_message, upsert_chat_session, ) -from backend.copilot.prompting import get_baseline_supplement +from backend.copilot.pending_message_helpers import ( + combine_pending_with_current, + drain_pending_safe, + persist_pending_as_user_rows, + persist_session_safe, +) +from backend.copilot.pending_messages import ( + drain_pending_messages, + format_pending_as_user_message, +) +from backend.copilot.prompting import SHARED_TOOL_NOTES, get_graphiti_supplement from backend.copilot.response_model import ( StreamBaseResponse, StreamError, @@ -44,13 +73,31 @@ from backend.copilot.response_model import ( ) from backend.copilot.service import ( _build_system_prompt, - _generate_session_title, _get_openai_client, + _update_title_async, config, + inject_user_context, + strip_user_context_tags, ) +from backend.copilot.session_cleanup import prune_orphan_tool_calls +from backend.copilot.thinking_stripper import ThinkingStripper as _ThinkingStripper from backend.copilot.token_tracking import persist_and_record_usage from backend.copilot.tools import execute_tool, get_available_tools from backend.copilot.tracking import track_user_message +from backend.copilot.transcript import ( + STOP_REASON_END_TURN, + STOP_REASON_TOOL_USE, + TranscriptDownload, + detect_gap, + download_transcript, + extract_context_messages, + strip_for_upload, + upload_transcript, + validate_transcript, +) +from backend.copilot.transcript_builder import TranscriptBuilder +from backend.data.db_accessors import chat_db +from backend.util import json as util_json from backend.util.exceptions import NotFoundError from backend.util.prompt import ( compress_context, @@ -64,6 +111,9 @@ from backend.util.tool_call_loop import ( tool_call_loop, ) +if TYPE_CHECKING: + from backend.copilot.permissions import CopilotPermissions + logger = logging.getLogger(__name__) # Set to hold background tasks to prevent garbage collection @@ -72,6 +122,214 @@ _background_tasks: set[asyncio.Task[Any]] = set() # Maximum number of tool-call rounds before forcing a text response. _MAX_TOOL_ROUNDS = 30 +# Max seconds to wait for transcript upload in the finally block before +# letting it continue as a background task (tracked in _background_tasks). +_TRANSCRIPT_UPLOAD_TIMEOUT_S = 5 + +# MIME types that can be embedded as vision content blocks (OpenAI format). +_VISION_MIME_TYPES = frozenset({"image/png", "image/jpeg", "image/gif", "image/webp"}) + + +# Max size for embedding images directly in the user message (20 MiB raw). +_MAX_INLINE_IMAGE_BYTES = 20 * 1024 * 1024 + +# Matches characters unsafe for filenames. +_UNSAFE_FILENAME = re.compile(r"[^\w.\-]") + +# OpenRouter-specific extra_body flag that embeds the real generation cost +# into the final usage chunk. Module-level constant so we don't reallocate +# an identical dict on every streaming call. +_OPENROUTER_INCLUDE_USAGE_COST = {"usage": {"include": True}} + + +def _extract_usage_cost(usage: CompletionUsage) -> float | None: + """Return the provider-reported USD cost on a streaming usage chunk. + + OpenRouter piggybacks a ``cost`` field on the OpenAI-compatible usage + object when the request body includes ``usage: {"include": True}``. + The OpenAI SDK's typed ``CompletionUsage`` does not declare it, so we + read it off ``model_extra`` (the pydantic v2 container for extras) to + keep the access fully typed — no ``getattr``. + + Returns ``None`` when the field is absent, explicitly null, + non-numeric, non-finite, or negative. Invalid values (including + present-but-null) are logged here — they indicate a provider bug + worth chasing; plain absences are silent so the caller can dedupe + the "missing cost" warning per stream. + """ + extras = usage.model_extra or {} + if "cost" not in extras: + return None + raw = extras["cost"] + if raw is None: + logger.error("[Baseline] usage.cost is present but null") + return None + try: + val = float(raw) + except (TypeError, ValueError): + logger.error("[Baseline] usage.cost is not numeric: %r", raw) + return None + if not math.isfinite(val) or val < 0: + logger.error("[Baseline] usage.cost is non-finite or negative: %r", val) + return None + return val + + +def _extract_cache_creation_tokens(ptd: PromptTokensDetails) -> int: + """Return cache-write token count from an OpenAI-compatible + ``PromptTokensDetails``, handling provider-specific field names and + SDK-version shape differences. + + Two shapes we care about: + + - **OpenRouter** (our primary baseline provider) streams the cache-write + count as ``cache_write_tokens``. Newer ``openai-python`` versions + declare this as a typed attribute on ``PromptTokensDetails``; older + versions expose it only in ``model_extra``. Verified empirically: + cold-cache request returns ``cache_write_tokens`` > 0, warm-cache + request returns ``cached_tokens`` > 0 and ``cache_write_tokens`` = 0. + - **Direct Anthropic API** uses ``cache_creation_input_tokens`` — + never a typed attribute on the OpenAI SDK, always lives in + ``model_extra``. + + Lookup order: typed attr → ``model_extra`` (OpenRouter) → ``model_extra`` + (Anthropic-native). ``getattr`` handles both the typed-attr case + (newer SDK) and the no-such-attr case (older SDK) — we can't only use + ``model_extra`` because when the field is typed it's filtered out of + ``model_extra``, leaving us at 0 on the modern happy path. + """ + typed_val = getattr(ptd, "cache_write_tokens", None) + if typed_val: + return int(typed_val) + extras = ptd.model_extra or {} + return int( + extras.get("cache_write_tokens") + or extras.get("cache_creation_input_tokens") + or 0 + ) + + +async def _prepare_baseline_attachments( + file_ids: list[str], + user_id: str, + session_id: str, + working_dir: str, +) -> tuple[str, list[dict[str, Any]]]: + """Download workspace files and prepare them for the baseline LLM. + + Images become OpenAI-format vision content blocks. Non-image files are + saved to *working_dir* so tool handlers can access them. + + Returns ``(hint_text, image_blocks)``. + """ + if not file_ids or not user_id: + return "", [] + + try: + manager = await get_workspace_manager(user_id, session_id) + except Exception: + logger.warning( + "Failed to create workspace manager for file attachments", + exc_info=True, + ) + return "", [] + + image_blocks: list[dict[str, Any]] = [] + file_descriptions: list[str] = [] + + for fid in file_ids: + try: + file_info = await manager.get_file_info(fid) + if file_info is None: + continue + content = await manager.read_file_by_id(fid) + mime = (file_info.mime_type or "").split(";")[0].strip().lower() + + if mime in _VISION_MIME_TYPES and len(content) <= _MAX_INLINE_IMAGE_BYTES: + b64 = base64.b64encode(content).decode("ascii") + image_blocks.append( + { + "type": "image", + "source": {"type": "base64", "media_type": mime, "data": b64}, + } + ) + file_descriptions.append( + f"- {file_info.name} ({mime}, " + f"{file_info.size_bytes:,} bytes) [embedded as image]" + ) + else: + safe = _UNSAFE_FILENAME.sub("_", file_info.name) or "file" + candidate = os.path.join(working_dir, safe) + if os.path.exists(candidate): + stem, ext = os.path.splitext(safe) + idx = 1 + while os.path.exists(candidate): + candidate = os.path.join(working_dir, f"{stem}_{idx}{ext}") + idx += 1 + with open(candidate, "wb") as f: + f.write(content) + file_descriptions.append( + f"- {file_info.name} ({mime}, " + f"{file_info.size_bytes:,} bytes) saved to " + f"{os.path.basename(candidate)}" + ) + except Exception: + logger.warning("Failed to prepare file %s", fid[:12], exc_info=True) + + if not file_descriptions: + return "", [] + + noun = "file" if len(file_descriptions) == 1 else "files" + has_non_images = len(file_descriptions) > len(image_blocks) + read_hint = ( + " Use the read_workspace_file tool to view non-image files." + if has_non_images + else "" + ) + hint = ( + f"\n[The user attached {len(file_descriptions)} {noun}.{read_hint}\n" + + "\n".join(file_descriptions) + + "]" + ) + return hint, image_blocks + + +def _filter_tools_by_permissions( + tools: list[ChatCompletionToolParam], + permissions: "CopilotPermissions", +) -> list[ChatCompletionToolParam]: + """Filter OpenAI-format tools based on CopilotPermissions. + + Uses short tool names (the ``function.name`` field) to compute the + effective allowed set, then keeps only matching tools. + """ + from backend.copilot.permissions import all_known_tool_names + + if permissions.is_empty(): + return tools + + all_tools = all_known_tool_names() + effective = permissions.effective_allowed_tools(all_tools) + + return [ + t + for t in tools + if t.get("function", {}).get("name") in effective # type: ignore[union-attr] + ] + + +def _resolve_baseline_model(tier: CopilotLlmModel | None) -> str: + """Pick the model for the baseline path based on the per-request tier. + + The baseline (fast) and SDK (extended thinking) paths now share the + same tier-based model resolution — only the *path* differs between + "fast" and "extended_thinking". ``'advanced'`` → Opus; + ``'standard'`` / ``None`` → the config default (Sonnet). + """ + from backend.copilot.service import resolve_chat_model + + return resolve_chat_model(tier) + @dataclass class _BaselineStreamState: @@ -81,12 +339,171 @@ class _BaselineStreamState: can be module-level functions instead of deeply nested closures. """ + model: str = "" pending_events: list[StreamBaseResponse] = field(default_factory=list) assistant_text: str = "" text_block_id: str = field(default_factory=lambda: str(uuid.uuid4())) text_started: bool = False + reasoning_emitter: BaselineReasoningEmitter = field(init=False) turn_prompt_tokens: int = 0 turn_completion_tokens: int = 0 + turn_cache_read_tokens: int = 0 + turn_cache_creation_tokens: int = 0 + cost_usd: float | None = None + # Tracks whether we've already warned about a missing `cost` field in + # the usage chunk this stream, so non-OpenRouter providers don't + # generate one warning per streaming call. + cost_missing_logged: bool = False + thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper) + # MUTATE in place only — ``__post_init__`` hands this list reference to + # ``BaselineReasoningEmitter`` so reasoning rows can be appended as + # deltas stream in. Reassigning (``state.session_messages = [...]``) + # would silently detach the emitter from the new list. + session_messages: list[ChatMessage] = field(default_factory=list) + # Tracks how much of ``assistant_text`` has already been flushed to + # ``session.messages`` via mid-loop pending drains, so the ``finally`` + # block only appends the *new* assistant text (avoiding duplication of + # round-1 text when round-1 entries were cleared from session_messages). + _flushed_assistant_text_len: int = 0 + # Memoised system-message dict with cache_control applied. The system + # prompt is static within a session, so we build it once on the first + # LLM round and reuse the same dict on subsequent rounds — avoiding + # an O(N) dict-copy of the growing ``messages`` list on every tool-call + # iteration. ``None`` means "not yet computed" (or the first message + # wasn't a system role, so no marking applies). + cached_system_message: dict[str, Any] | None = None + + def __post_init__(self) -> None: + # Wire the reasoning emitter to ``session_messages`` so it can + # append ``role="reasoning"`` rows as reasoning streams in — the + # frontend's ``convertChatSessionToUiMessages`` relies on these + # rows to render the Reasoning collapse after the AI SDK's + # stream-end hydrate swaps in the DB-backed message list. + self.reasoning_emitter = BaselineReasoningEmitter(self.session_messages) + + +def _is_anthropic_model(model: str) -> bool: + """Return True if *model* routes to Anthropic (native or via OpenRouter). + + Cache-control markers on message content + the ``anthropic-beta`` header + are Anthropic-specific. OpenAI rejects the unknown ``cache_control`` + field with a 400 ("Extra inputs are not permitted") and Grok / other + providers behave similarly. OpenRouter strips unknown headers but + passes through ``cache_control`` on the body regardless of provider — + which would also fail when OpenRouter routes to a non-Anthropic model. + + Examples that return True: + - ``anthropic/claude-sonnet-4-6`` (OpenRouter route) + - ``claude-3-5-sonnet-20241022`` (direct Anthropic API) + - ``anthropic.claude-3-5-sonnet`` (Bedrock-style) + + False for ``openai/gpt-4o``, ``google/gemini-2.5-pro``, ``xai/grok-4`` + etc. + """ + lowered = model.lower() + return "claude" in lowered or lowered.startswith("anthropic") + + +def _fresh_ephemeral_cache_control() -> dict[str, str]: + """Return a FRESH ephemeral ``cache_control`` dict each call. + + The ``ttl`` is sourced from :attr:`ChatConfig.baseline_prompt_cache_ttl` + (default ``1h``) so the static prefix stays warm across many users' + requests in the same workspace cache. Anthropic caches are keyed + per-workspace, so every copilot user reading the same system prompt + hits the same cached entry. + + Using a shared module-level dict would let any downstream mutation + (e.g. the OpenAI SDK normalising fields in-place) poison every future + request's marker. Construction is O(1) so the safety margin is free. + """ + return {"type": "ephemeral", "ttl": config.baseline_prompt_cache_ttl} + + +def _fresh_anthropic_caching_headers() -> dict[str, str]: + """Return a FRESH ``extra_headers`` dict requesting the Anthropic + prompt-caching beta. + + Same reasoning as :func:`_fresh_ephemeral_cache_control`: never hand a + shared module-level dict to third-party SDKs. OpenRouter auto-forwards + cache_control for Anthropic routes without this header, but passing it + makes the intent unambiguous on-wire and is a no-op for non-Anthropic + providers (unknown headers are dropped). + """ + return {"anthropic-beta": "prompt-caching-2024-07-31"} + + +def _mark_tools_with_cache_control( + tools: Sequence[Mapping[str, Any]], +) -> list[dict[str, Any]]: + """Return a copy of *tools* with ``cache_control`` on the last entry. + + Marking the last tool is a cache breakpoint that covers the whole tool + schema block as a cacheable prefix segment. Extracted from + :func:`_mark_system_message_with_cache_control` so callers can precompute + the marked tool list once per session — the tool set is static within a + request and the ~43 dict-copies would otherwise run on every LLM round + in the tool-call loop. + + **Only call this for Anthropic model routes.** Non-Anthropic providers + (OpenAI, Grok, Gemini) reject the unknown ``cache_control`` field with + a 400 schema validation error. Gate via :func:`_is_anthropic_model`. + """ + cached: list[dict[str, Any]] = [dict(t) for t in tools] + if cached: + cached[-1] = { + **cached[-1], + "cache_control": _fresh_ephemeral_cache_control(), + } + return cached + + +def _build_cached_system_message( + system_message: Mapping[str, Any], +) -> dict[str, Any]: + """Return a copy of *system_message* with ``cache_control`` applied. + + Anthropic's cache uses prefix-match with up to 4 explicit breakpoints. + Combined with the last-tool marker this gives two cache segments — the + system block alone, and system+all-tools — so requests that share only + the system prefix still get a partial cache hit. + + The system message is rebuilt via spread (``{**original, ...}``) so any + unknown fields the caller set (e.g. ``name``) survive the transformation. + Non-Anthropic models silently ignore the markers. + + Returns the original dict (shallow-copied) unchanged when the content + shape is unsupported (missing / non-string / empty) — callers should + splice it into the message list as-is in that case. + """ + sys_copy = dict(system_message) + sys_content = sys_copy.get("content") + if isinstance(sys_content, str) and sys_content: + sys_copy["content"] = [ + { + "type": "text", + "text": sys_content, + "cache_control": _fresh_ephemeral_cache_control(), + } + ] + return sys_copy + + +def _mark_system_message_with_cache_control( + messages: Sequence[Mapping[str, Any]], +) -> list[dict[str, Any]]: + """Return a copy of *messages* with ``cache_control`` on the system block. + + Thin wrapper around :func:`_build_cached_system_message` that preserves + the original list shape. Prefer the memoised path in + ``_baseline_llm_caller`` (which builds the cached system dict once per + session) for hot-loop callers; this function is retained for call sites + outside the tool-call loop where per-call copying is acceptable. + """ + cached_messages: list[dict[str, Any]] = [dict(m) for m in messages] + if cached_messages and cached_messages[0].get("role") == "system": + cached_messages[0] = _build_cached_system_message(cached_messages[0]) + return cached_messages async def _baseline_llm_caller( @@ -100,70 +517,177 @@ async def _baseline_llm_caller( Extracted from ``stream_chat_completion_baseline`` for readability. """ state.pending_events.append(StreamStartStep()) + # Fresh thinking-strip state per round so a malformed unclosed + # block in one LLM call cannot silently drop content in the next. + state.thinking_stripper = _ThinkingStripper() round_text = "" try: client = _get_openai_client() - typed_messages = cast(list[ChatCompletionMessageParam], messages) - if tools: - typed_tools = cast(list[ChatCompletionToolParam], tools) - response = await client.chat.completions.create( - model=config.model, - messages=typed_messages, - tools=typed_tools, - stream=True, - stream_options={"include_usage": True}, - ) + # Cache markers are Anthropic-specific. For OpenAI/Grok/other + # providers, leaving them on would trigger a 400 ("Extra inputs + # are not permitted" on cache_control). Tools were precomputed + # in stream_chat_completion_baseline via _mark_tools_with_cache_control + # (only when the model was Anthropic), so on non-Anthropic routes + # tools ship without cache_control on the last entry too. + # + # `extra_body` `usage.include=true` asks OpenRouter to embed the real + # generation cost into the final usage chunk — required by the + # cost-based rate limiter in routes.py. Separate from the Anthropic + # caching headers, always sent. + is_anthropic = _is_anthropic_model(state.model) + if is_anthropic: + # Build the cached system dict once per session and splice it in + # on each round. The full ``messages`` list grows with every + # tool call, so copying the entire list just to mutate index 0 + # scales with conversation length (sentry flagged this); this + # splice touches only list slots, not message contents. + if ( + state.cached_system_message is None + and messages + and messages[0].get("role") == "system" + ): + state.cached_system_message = _build_cached_system_message(messages[0]) + if state.cached_system_message is not None and messages: + final_messages = [state.cached_system_message, *messages[1:]] + else: + final_messages = messages + extra_headers = _fresh_anthropic_caching_headers() else: - response = await client.chat.completions.create( - model=config.model, - messages=typed_messages, - stream=True, - stream_options={"include_usage": True}, - ) + final_messages = messages + extra_headers = None + typed_messages = cast(list[ChatCompletionMessageParam], final_messages) + extra_body: dict[str, Any] = dict(_OPENROUTER_INCLUDE_USAGE_COST) + reasoning_param = reasoning_extra_body( + state.model, config.claude_agent_max_thinking_tokens + ) + if reasoning_param: + extra_body.update(reasoning_param) + create_kwargs: dict[str, Any] = { + "model": state.model, + "messages": typed_messages, + "stream": True, + "stream_options": {"include_usage": True}, + "extra_body": extra_body, + } + if extra_headers: + create_kwargs["extra_headers"] = extra_headers + if tools: + create_kwargs["tools"] = cast(list[ChatCompletionToolParam], list(tools)) + response = await client.chat.completions.create(**create_kwargs) tool_calls_by_index: dict[int, dict[str, str]] = {} - async for chunk in response: - if chunk.usage: - state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0 - state.turn_completion_tokens += chunk.usage.completion_tokens or 0 + # Iterate under an inner try/finally so early exits (cancel, tool-call + # break, exception) always release the underlying httpx connection. + # Without this, openai.AsyncStream leaks the streaming response and + # the TCP socket ends up in CLOSE_WAIT until the process exits. + try: + async for chunk in response: + if chunk.usage: + state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0 + state.turn_completion_tokens += chunk.usage.completion_tokens or 0 + ptd = chunk.usage.prompt_tokens_details + if ptd: + state.turn_cache_read_tokens += ptd.cached_tokens or 0 + state.turn_cache_creation_tokens += ( + _extract_cache_creation_tokens(ptd) + ) + cost = _extract_usage_cost(chunk.usage) + if cost is not None: + state.cost_usd = (state.cost_usd or 0.0) + cost + elif ( + "cost" not in (chunk.usage.model_extra or {}) + and not state.cost_missing_logged + ): + # Field absent (non-OpenRouter route, or OpenRouter + # misconfigured) — warn once per stream so error + # monitoring picks up persistent misses without + # flooding. Invalid values already logged inside + # _extract_usage_cost, so no duplicate warning here. + logger.warning( + "[Baseline] usage chunk missing cost (model=%s, " + "prompt=%s, completion=%s) — rate-limit will " + "skip this call", + state.model, + chunk.usage.prompt_tokens, + chunk.usage.completion_tokens, + ) + state.cost_missing_logged = True - delta = chunk.choices[0].delta if chunk.choices else None - if not delta: - continue + delta = chunk.choices[0].delta if chunk.choices else None + if not delta: + continue - if delta.content: - if not state.text_started: - state.pending_events.append(StreamTextStart(id=state.text_block_id)) - state.text_started = True - round_text += delta.content - state.pending_events.append( - StreamTextDelta(id=state.text_block_id, delta=delta.content) - ) + state.pending_events.extend(state.reasoning_emitter.on_delta(delta)) - if delta.tool_calls: - for tc in delta.tool_calls: - idx = tc.index - if idx not in tool_calls_by_index: - tool_calls_by_index[idx] = { - "id": "", - "name": "", - "arguments": "", - } - entry = tool_calls_by_index[idx] - if tc.id: - entry["id"] = tc.id - if tc.function and tc.function.name: - entry["name"] = tc.function.name - if tc.function and tc.function.arguments: - entry["arguments"] += tc.function.arguments + if delta.content: + # Text and reasoning must not interleave on the wire — the + # AI SDK maps distinct start/end pairs to distinct UI + # parts. Close any open reasoning block before emitting + # the first text delta of this run. + state.pending_events.extend(state.reasoning_emitter.close()) + emit = state.thinking_stripper.process(delta.content) + if emit: + if not state.text_started: + state.pending_events.append( + StreamTextStart(id=state.text_block_id) + ) + state.text_started = True + round_text += emit + state.pending_events.append( + StreamTextDelta(id=state.text_block_id, delta=emit) + ) - # Close text block + if delta.tool_calls: + # Same rule as the text branch: close any open reasoning + # block before a tool_use starts so the AI SDK treats + # reasoning and tool-use as distinct parts. + state.pending_events.extend(state.reasoning_emitter.close()) + for tc in delta.tool_calls: + idx = tc.index + if idx not in tool_calls_by_index: + tool_calls_by_index[idx] = { + "id": "", + "name": "", + "arguments": "", + } + entry = tool_calls_by_index[idx] + if tc.id: + entry["id"] = tc.id + if tc.function and tc.function.name: + entry["name"] = tc.function.name + if tc.function and tc.function.arguments: + entry["arguments"] += tc.function.arguments + finally: + # Release the streaming httpx connection back to the pool on every + # exit path (normal completion, break, exception). openai.AsyncStream + # does not auto-close when the async-for loop exits early. + try: + await response.close() + except Exception: + pass + + finally: + # Close open blocks on both normal and exception paths so the + # frontend always sees matched start/end pairs. An exception mid + # ``async for chunk in response`` would otherwise leave reasoning + # and/or text unterminated and only ``StreamFinishStep`` emitted — + # the Reasoning / Text collapses would never finalise. + state.pending_events.extend(state.reasoning_emitter.close()) + # Flush any buffered text held back by the thinking stripper. + tail = state.thinking_stripper.flush() + if tail: + if not state.text_started: + state.pending_events.append(StreamTextStart(id=state.text_block_id)) + state.text_started = True + round_text += tail + state.pending_events.append( + StreamTextDelta(id=state.text_block_id, delta=tail) + ) if state.text_started: state.pending_events.append(StreamTextEnd(id=state.text_block_id)) state.text_started = False state.text_block_id = str(uuid.uuid4()) - finally: # Always persist partial text so the session history stays consistent, # even when the stream is interrupted by an exception. state.assistant_text += round_text @@ -278,17 +802,17 @@ async def _baseline_tool_executor( ) -def _baseline_conversation_updater( +def _mutate_openai_messages( messages: list[dict[str, Any]], response: LLMLoopResponse, - tool_results: list[ToolCallResult] | None = None, + tool_results: list[ToolCallResult] | None, ) -> None: - """Update OpenAI message list with assistant response + tool results. + """Append assistant / tool-result entries to the OpenAI message list. - Extracted from ``stream_chat_completion_baseline`` for readability. + This is the side-effect boundary for the next LLM call — no transcript + mutation happens here. """ if tool_results: - # Build assistant message with tool_calls assistant_msg: dict[str, Any] = {"role": "assistant"} if response.response_text: assistant_msg["content"] = response.response_text @@ -309,25 +833,120 @@ def _baseline_conversation_updater( "content": tr.content, } ) - else: - if response.response_text: - messages.append({"role": "assistant", "content": response.response_text}) + elif response.response_text: + messages.append({"role": "assistant", "content": response.response_text}) -async def _update_title_async( - session_id: str, message: str, user_id: str | None +def _record_turn_to_transcript( + response: LLMLoopResponse, + tool_results: list[ToolCallResult] | None, + *, + transcript_builder: TranscriptBuilder, + model: str, ) -> None: - """Generate and persist a session title in the background.""" - try: - title = await _generate_session_title(message, user_id, session_id) - if title and user_id: - await update_session_title(session_id, user_id, title, only_if_empty=True) - except Exception as e: - logger.warning("[Baseline] Failed to update session title: %s", e) + """Append assistant + tool-result entries to the transcript builder. + + Kept separate from :func:`_mutate_openai_messages` so the two + concerns (next-LLM-call payload vs. durable conversation log) can + evolve independently. + """ + if tool_results: + content_blocks: list[dict[str, Any]] = [] + if response.response_text: + content_blocks.append({"type": "text", "text": response.response_text}) + for tc in response.tool_calls: + try: + args = orjson.loads(tc.arguments) if tc.arguments else {} + except (ValueError, TypeError, orjson.JSONDecodeError) as parse_err: + logger.debug( + "[Baseline] Failed to parse tool_call arguments " + "(tool=%s, id=%s): %s", + tc.name, + tc.id, + parse_err, + ) + args = {} + content_blocks.append( + { + "type": "tool_use", + "id": tc.id, + "name": tc.name, + "input": args, + } + ) + if content_blocks: + transcript_builder.append_assistant( + content_blocks=content_blocks, + model=model, + stop_reason=STOP_REASON_TOOL_USE, + ) + for tr in tool_results: + # Record tool result to transcript AFTER the assistant tool_use + # block to maintain correct Anthropic API ordering: + # assistant(tool_use) → user(tool_result) + transcript_builder.append_tool_result( + tool_use_id=tr.tool_call_id, + content=tr.content, + ) + elif response.response_text: + transcript_builder.append_assistant( + content_blocks=[{"type": "text", "text": response.response_text}], + model=model, + stop_reason=STOP_REASON_END_TURN, + ) + + +def _baseline_conversation_updater( + messages: list[dict[str, Any]], + response: LLMLoopResponse, + tool_results: list[ToolCallResult] | None = None, + *, + transcript_builder: TranscriptBuilder, + model: str = "", + state: _BaselineStreamState | None = None, +) -> None: + """Update OpenAI message list with assistant response + tool results. + + Also records structured ChatMessage entries in ``state.session_messages`` + so the full tool-call history is persisted to the session (not just the + concatenated assistant text). + """ + _mutate_openai_messages(messages, response, tool_results) + _record_turn_to_transcript( + response, + tool_results, + transcript_builder=transcript_builder, + model=model, + ) + # Record structured messages for session persistence so tool calls + # and tool results survive across turns and mode switches. + if state is not None and tool_results: + assistant_msg = ChatMessage( + role="assistant", + content=response.response_text or "", + tool_calls=[ + { + "id": tc.id, + "type": "function", + "function": {"name": tc.name, "arguments": tc.arguments}, + } + for tc in response.tool_calls + ], + ) + state.session_messages.append(assistant_msg) + for tr in tool_results: + state.session_messages.append( + ChatMessage( + role="tool", + content=tr.content, + tool_call_id=tr.tool_call_id, + ) + ) async def _compress_session_messages( messages: list[ChatMessage], + model: str, ) -> list[ChatMessage]: """Compress session messages if they exceed the model's token limit. @@ -340,45 +959,257 @@ async def _compress_session_messages( msg_dict: dict[str, Any] = {"role": msg.role} if msg.content: msg_dict["content"] = msg.content + if msg.tool_calls: + msg_dict["tool_calls"] = msg.tool_calls + if msg.tool_call_id: + msg_dict["tool_call_id"] = msg.tool_call_id messages_dict.append(msg_dict) try: result = await compress_context( messages=messages_dict, - model=config.model, + model=model, client=_get_openai_client(), ) except Exception as e: logger.warning("[Baseline] Context compression with LLM failed: %s", e) result = await compress_context( messages=messages_dict, - model=config.model, + model=model, client=None, ) if result.was_compacted: logger.info( - "[Baseline] Context compacted: %d -> %d tokens " - "(%d summarized, %d dropped)", + "[Baseline] Context compacted: %d -> %d tokens (%d summarized, %d dropped)", result.original_token_count, result.token_count, result.messages_summarized, result.messages_dropped, ) return [ - ChatMessage(role=m["role"], content=m.get("content")) + ChatMessage( + role=m["role"], + content=m.get("content"), + tool_calls=m.get("tool_calls"), + tool_call_id=m.get("tool_call_id"), + ) for m in result.messages ] return messages +def should_upload_transcript(user_id: str | None, upload_safe: bool) -> bool: + """Return ``True`` when the caller should upload the final transcript. + + Uploads require a logged-in user (for the storage key) *and* a safe + upload signal from ``_load_prior_transcript`` — i.e. GCS does not hold a + newer version that we'd be overwriting. + """ + return bool(user_id) and upload_safe + + +def _append_gap_to_builder( + gap: list[ChatMessage], + builder: TranscriptBuilder, +) -> None: + """Append gap messages from chat-db into the TranscriptBuilder. + + Converts ChatMessage (OpenAI format) to TranscriptBuilder entries + (Claude CLI JSONL format) so the uploaded transcript covers all turns. + + Pre-condition: ``gap`` always starts at a user or assistant boundary + (never mid-turn at a ``tool`` role), because ``detect_gap`` enforces + ``session_messages[wm-1].role == 'assistant'`` before returning a non-empty + gap. Any ``tool`` role messages within the gap always follow an assistant + entry that already exists in the builder or in the gap itself. + """ + for msg in gap: + if msg.role == "user": + builder.append_user(msg.content or "") + elif msg.role == "assistant": + content_blocks: list[dict] = [] + if msg.content: + content_blocks.append({"type": "text", "text": msg.content}) + if msg.tool_calls: + for tc in msg.tool_calls: + fn = tc.get("function", {}) if isinstance(tc, dict) else {} + input_data = util_json.loads(fn.get("arguments", "{}"), fallback={}) + content_blocks.append( + { + "type": "tool_use", + "id": tc.get("id", "") if isinstance(tc, dict) else "", + "name": fn.get("name", "unknown"), + "input": input_data, + } + ) + if not content_blocks: + # Fallback: ensure every assistant gap message produces an entry + # so the builder's entry count matches the gap length. + content_blocks.append({"type": "text", "text": ""}) + builder.append_assistant(content_blocks=content_blocks) + elif msg.role == "tool": + if msg.tool_call_id: + builder.append_tool_result( + tool_use_id=msg.tool_call_id, + content=msg.content or "", + ) + else: + # Malformed tool message — no tool_call_id to link to an + # assistant tool_use block. Skip to avoid an unmatched + # tool_result entry in the builder (which would confuse --resume). + logger.warning( + "[Baseline] Skipping tool gap message with no tool_call_id" + ) + + +async def _load_prior_transcript( + user_id: str, + session_id: str, + session_messages: list[ChatMessage], + transcript_builder: TranscriptBuilder, +) -> tuple[bool, "TranscriptDownload | None"]: + """Download and load the prior CLI session into ``transcript_builder``. + + Returns a tuple of (upload_safe, transcript_download): + - ``upload_safe`` is ``True`` when it is safe to upload at the end of this + turn. Upload is suppressed only for **download errors** (unknown GCS + state) — missing and invalid files return ``True`` because there is + nothing in GCS worth protecting against overwriting. + - ``transcript_download`` is a ``TranscriptDownload`` with str content + (pre-decoded and stripped) when available, or ``None`` when no valid + transcript could be loaded. Callers pass this to + ``extract_context_messages`` to build the LLM context. + """ + try: + restore = await download_transcript( + user_id, session_id, log_prefix="[Baseline]" + ) + except Exception as e: + logger.warning("[Baseline] Session restore failed: %s", e) + # Unknown GCS state — be conservative, skip upload. + return False, None + + if restore is None: + logger.debug("[Baseline] No CLI session available — will upload fresh") + # Nothing in GCS to protect; allow upload so the first baseline turn + # writes the initial transcript snapshot. + return True, None + + content_bytes = restore.content + try: + raw_str = ( + content_bytes.decode("utf-8") + if isinstance(content_bytes, bytes) + else content_bytes + ) + except UnicodeDecodeError: + logger.warning("[Baseline] CLI session content is not valid UTF-8") + # Corrupt file in GCS; overwriting with a valid one is better. + return True, None + + stripped = strip_for_upload(raw_str) + if not validate_transcript(stripped): + logger.warning("[Baseline] CLI session content invalid after strip") + # Corrupt file in GCS; overwriting with a valid one is better. + return True, None + + transcript_builder.load_previous(stripped, log_prefix="[Baseline]") + logger.info( + "[Baseline] Loaded CLI session: %dB, msg_count=%d", + len(content_bytes) if isinstance(content_bytes, bytes) else len(raw_str), + restore.message_count, + ) + + gap = detect_gap(restore, session_messages) + if gap: + _append_gap_to_builder(gap, transcript_builder) + logger.info( + "[Baseline] Filled gap: loaded %d transcript msgs + %d gap msgs from DB", + restore.message_count, + len(gap), + ) + + # Return a str-content version so extract_context_messages receives a + # pre-decoded, stripped transcript (avoids redundant decode + strip). + # TranscriptDownload.content is typed as bytes | str; we pass str here + # to avoid a redundant encode + decode round-trip. + str_restore = TranscriptDownload( + content=stripped, + message_count=restore.message_count, + mode=restore.mode, + ) + return True, str_restore + + +async def _upload_final_transcript( + user_id: str, + session_id: str, + transcript_builder: TranscriptBuilder, + session_msg_count: int, +) -> None: + """Serialize and upload the transcript for next-turn continuity. + + Uses the builder's own invariants to decide whether to upload, + avoiding a JSONL re-parse. A builder that ends with an assistant + entry is structurally complete; a builder that doesn't (empty, or + ends mid-turn) is skipped. + """ + try: + if transcript_builder.last_entry_type != "assistant": + logger.debug( + "[Baseline] No complete assistant turn to upload (last_entry=%s)", + transcript_builder.last_entry_type, + ) + return + content = transcript_builder.to_jsonl() + if not content: + logger.debug("[Baseline] Empty transcript content, skipping upload") + return + # Track the upload as a background task so a timeout doesn't leak an + # orphaned coroutine; shield it so cancellation of this caller doesn't + # abort the in-flight GCS write. + upload_task = asyncio.create_task( + upload_transcript( + user_id=user_id, + session_id=session_id, + content=content.encode("utf-8"), + message_count=session_msg_count, + mode="baseline", + log_prefix="[Baseline]", + ) + ) + _background_tasks.add(upload_task) + upload_task.add_done_callback(_background_tasks.discard) + # Bound the wait: a hung storage backend must not block the response + # from finishing. The task keeps running in _background_tasks on + # timeout and will be cleaned up when it resolves. + await asyncio.wait_for( + asyncio.shield(upload_task), timeout=_TRANSCRIPT_UPLOAD_TIMEOUT_S + ) + except asyncio.TimeoutError: + # Upload is still running in _background_tasks; we just stopped waiting. + logger.info( + "[Baseline] Transcript upload exceeded %ss wait — continuing as background task", + _TRANSCRIPT_UPLOAD_TIMEOUT_S, + ) + except Exception as upload_err: + logger.error("[Baseline] Transcript upload failed: %s", upload_err) + + async def stream_chat_completion_baseline( session_id: str, message: str | None = None, is_user_message: bool = True, user_id: str | None = None, session: ChatSession | None = None, + file_ids: list[str] | None = None, + permissions: "CopilotPermissions | None" = None, + context: dict[str, str] | None = None, + mode: CopilotMode | None = None, + model: CopilotLlmModel | None = None, + request_arrival_at: float = 0.0, **_kwargs: Any, ) -> AsyncGenerator[StreamBaseResponse, None]: """Baseline LLM with tool calling via OpenAI-compatible API. @@ -397,24 +1228,151 @@ async def stream_chat_completion_baseline( f"Session {session_id} not found. Please create a new session first." ) - # Append user message - new_role = "user" if is_user_message else "assistant" - if message and ( - len(session.messages) == 0 - or not ( - session.messages[-1].role == new_role - and session.messages[-1].content == message - ) - ): - session.messages.append(ChatMessage(role=new_role, content=message)) + # Drop orphan tool_use + trailing stop-marker rows left by a previous + # Stop mid-tool-call so the new turn starts from a well-formed message list. + prune_orphan_tool_calls( + session.messages, log_prefix=f"[Baseline] [{session_id[:12]}]" + ) + + # Strip any user-injected tags on every turn. + # Only the server-injected prefix on the first message is trusted. + if message: + message = strip_user_context_tags(message) + + if maybe_append_user_message(session, message, is_user_message): if is_user_message: track_user_message( user_id=user_id, session_id=session_id, - message_length=len(message), + message_length=len(message or ""), ) - session = await upsert_chat_session(session) + # Capture count *before* the pending drain so is_first_turn and the + # transcript staleness check are not skewed by queued messages. + _pre_drain_msg_count = len(session.messages) + + # Drain any messages the user queued via POST /messages/pending + # while this session was idle (or during a previous turn whose + # mid-loop drains missed them). + # The drained content is appended after ``message`` so the user's submitted + # message remains the leading context (better UX: the user sent their primary + # message first, queued follow-ups second). The already-saved user message + # in the DB is updated via update_message_content_by_sequence rather than + # inserting a new row, because routes.py has already saved the user message + # before the executor picks up the turn (using insert_pending_before_last + + # persist_session_safe would add a duplicate row at sequence N+1). + drained_at_start_pending = await drain_pending_safe(session_id, "[Baseline]") + if drained_at_start_pending: + logger.info( + "[Baseline] Draining %d pending message(s) at turn start for session %s", + len(drained_at_start_pending), + session_id, + ) + # Chronological combine: pending typed BEFORE this /stream + # request's arrival go ahead of ``message``; race-path follow-ups + # typed AFTER (queued while /stream was still processing) go + # after. See ``combine_pending_with_current`` for details. + message = combine_pending_with_current( + drained_at_start_pending, + message, + request_arrival_at=request_arrival_at, + ) + # Update the in-memory content of the already-saved user message + # and persist that update by sequence number. + last_user_msg = next( + (m for m in reversed(session.messages) if m.role == "user"), None + ) + if last_user_msg is None or last_user_msg.sequence is None: + # Defensive: routes.py always pre-saves the user message with a + # sequence before dispatch, so this is unreachable under normal + # flow. Raising instead of a warning-and-continue avoids silent + # data loss (in-memory message diverges from the DB row, so the + # queued chip would disappear from the UI after refresh without + # a corresponding bubble). + raise RuntimeError( + f"[Baseline] Cannot persist turn-start pending injection: " + f"last_user_msg={'missing' if last_user_msg is None else 'has no sequence'}" + ) + last_user_msg.content = message + await chat_db().update_message_content_by_sequence( + session_id, last_user_msg.sequence, message + ) + + # Select model based on the per-request tier toggle (standard / advanced). + # The path (fast vs extended_thinking) is already decided — we're in the + # baseline (fast) path; ``mode`` is accepted for logging parity only. + active_model = _resolve_baseline_model(model) + + # --- E2B sandbox setup (feature parity with SDK path) --- + e2b_sandbox = None + e2b_api_key = config.active_e2b_api_key + if e2b_api_key: + try: + from backend.copilot.tools.e2b_sandbox import get_or_create_sandbox + + e2b_sandbox = await get_or_create_sandbox( + session_id, + api_key=e2b_api_key, + template=config.e2b_sandbox_template, + timeout=config.e2b_sandbox_timeout, + on_timeout=config.e2b_sandbox_on_timeout, + ) + except Exception: + logger.warning("[Baseline] E2B sandbox setup failed", exc_info=True) + + # --- Transcript support (feature parity with SDK path) --- + transcript_builder = TranscriptBuilder() + transcript_upload_safe = True + + # Build system prompt only on the first turn to avoid mid-conversation + # changes from concurrent chats updating business understanding. + # Use the pre-drain count so queued pending messages don't incorrectly + # flip is_first_turn to False on an actual first turn. + is_first_turn = _pre_drain_msg_count <= 1 + # Gate context fetch on both first turn AND user message so that assistant- + # role calls (e.g. tool-result submissions) on the first turn don't trigger + # a needless DB lookup for user understanding. + should_inject_user_context = is_first_turn and is_user_message + + if should_inject_user_context: + prompt_task = _build_system_prompt(user_id) + else: + prompt_task = _build_system_prompt(None) + + # Run download + prompt build concurrently — both are independent I/O + # on the request critical path. Use the pre-drain count so pending + # messages drained at turn start don't spuriously trigger a transcript + # load on an actual first turn. + transcript_download: TranscriptDownload | None = None + if user_id and _pre_drain_msg_count > 1: + ( + (transcript_upload_safe, transcript_download), + (base_system_prompt, understanding), + ) = await asyncio.gather( + _load_prior_transcript( + user_id=user_id, + session_id=session_id, + session_messages=session.messages, + transcript_builder=transcript_builder, + ), + prompt_task, + ) + else: + base_system_prompt, understanding = await prompt_task + + # Append user message to transcript after context injection below so the + # transcript receives the prefixed message when user context is available. + + # NOTE: drained pending messages are folded into the current user + # message's content (see the turn-start drain above), so the single + # ``transcript_builder.append_user`` call below (covered by the + # ``if message and is_user_message`` branch that appends + # ``user_message_for_transcript or message``) already records the + # combined text in the transcript. Do NOT also append drained items + # individually here — on the ``transcript_download is None`` path + # that would produce N separate pending entries plus the combined + # entry, duplicating the pending content in the JSONL uploaded for + # the next turn's ``--resume``. # Generate title for new sessions if is_user_message and not session.title: @@ -430,36 +1388,211 @@ async def stream_chat_completion_baseline( message_id = str(uuid.uuid4()) - # Build system prompt only on the first turn to avoid mid-conversation - # changes from concurrent chats updating business understanding. - is_first_turn = len(session.messages) <= 1 - if is_first_turn: - base_system_prompt, _ = await _build_system_prompt( - user_id, has_conversation_history=False - ) - else: - base_system_prompt, _ = await _build_system_prompt( - user_id=None, has_conversation_history=True - ) + # Append tool documentation, technical notes, and Graphiti memory instructions + graphiti_enabled = await is_enabled_for_user(user_id) - # Append tool documentation and technical notes - system_prompt = base_system_prompt + get_baseline_supplement() + graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else "" + # Append the builder-session block (graph id+name + full building guide) + # AFTER the shared supplements so the system prompt is byte-identical + # across turns of the same builder session — Claude's prompt cache keeps + # the ~20KB guide warm for the whole session. Empty string for + # non-builder sessions keeps the cross-user cache hot. + builder_session_suffix = await build_builder_system_prompt_suffix(session) + system_prompt = ( + base_system_prompt + + SHARED_TOOL_NOTES + + graphiti_supplement + + builder_session_suffix + ) - # Compress context if approaching the model's token limit - messages_for_context = await _compress_session_messages(session.messages) + # Warm context: pre-load relevant facts from Graphiti on first turn. + # Use the pre-drain count so pending messages drained at turn start + # don't prevent warm context injection on an actual first turn. + # Stored here but injected into the user message (not the system prompt) + # after openai_messages is built — keeps system prompt static for caching. + warm_ctx: str | None = None + if graphiti_enabled and user_id and _pre_drain_msg_count <= 1: + from backend.copilot.graphiti.context import fetch_warm_context - # Build OpenAI message list from session history + warm_ctx = await fetch_warm_context(user_id, message or "") + + # Context path: transcript content (compacted, isCompactSummary preserved) + + # gap (DB messages after watermark) + current user turn. + # This avoids re-reading the full session history from DB on every turn. + # See extract_context_messages() in transcript.py for the shared primitive. + prior_context = extract_context_messages(transcript_download, session.messages) + messages_for_context = await _compress_session_messages( + prior_context + ([session.messages[-1]] if session.messages else []), + model=active_model, + ) + + # Build OpenAI message list from session history. + # Include tool_calls on assistant messages and tool-role results so the + # model retains full context of what tools were invoked and their outcomes. openai_messages: list[dict[str, Any]] = [ {"role": "system", "content": system_prompt} ] for msg in messages_for_context: - if msg.role in ("user", "assistant") and msg.content: + if msg.role == "assistant": + entry: dict[str, Any] = {"role": "assistant"} + if msg.content: + entry["content"] = msg.content + if msg.tool_calls: + entry["tool_calls"] = msg.tool_calls + if msg.content or msg.tool_calls: + openai_messages.append(entry) + elif msg.role == "tool" and msg.tool_call_id: + openai_messages.append( + { + "role": "tool", + "tool_call_id": msg.tool_call_id, + "content": msg.content or "", + } + ) + elif msg.role == "user" and msg.content: openai_messages.append({"role": msg.role, "content": msg.content}) + # Inject user context into the first user message on first turn. + # Done before attachment/URL injection so the context prefix lands at + # the very start of the message content. + user_message_for_transcript = message + if should_inject_user_context: + prefixed = await inject_user_context( + understanding, message or "", session_id, session.messages + ) + if prefixed is not None: + # Reverse scan so we update the current turn's user message, not + # the first (oldest) one when pending messages were drained. + for msg in reversed(openai_messages): + if msg["role"] == "user": + msg["content"] = prefixed + break + user_message_for_transcript = prefixed + else: + logger.warning("[Baseline] No user message found for context injection") + + # Inject Graphiti warm context into the current turn's user message (not + # the system prompt) so the system prompt stays static and cacheable. + # warm_ctx is already wrapped in . + # Appended AFTER user_context so stays at the very start. + # Reverse scan so we update the current turn's user message, not the + # oldest one when pending messages were drained. + if warm_ctx: + for msg in reversed(openai_messages): + if msg["role"] == "user": + existing = msg.get("content", "") + if isinstance(existing, str): + msg["content"] = f"{existing}\n\n{warm_ctx}" + break + # Do NOT append warm_ctx to user_message_for_transcript — it would + # persist stale temporal context into the transcript for future turns. + + # Inject the per-turn ```` prefix when the session is + # bound to a graph via ``metadata.builder_graph_id``. Runs on every + # user turn (not just the first) so the LLM always sees the live graph + # snapshot — if the user edits the graph between turns, the next turn + # carries the updated nodes/links. Only version + nodes + links here; + # the static guide + graph id live in the system prompt via + # ``build_builder_system_prompt_suffix`` (session-stable, prompt-cached). + # Prepended AFTER any // blocks + # — same trust tier as those server-injected prefixes. Not persisted to + # the transcript: the snapshot is stale-by-definition after the turn ends. + if is_user_message and session.metadata.builder_graph_id: + builder_block = await build_builder_context_turn_prefix(session, user_id) + if builder_block: + for msg in reversed(openai_messages): + if msg["role"] == "user": + existing = msg.get("content", "") + if isinstance(existing, str): + msg["content"] = builder_block + existing + break + + # Append user message to transcript. + # Always append when the message is present and is from the user, + # even on duplicate-suppressed retries (is_new_message=False). + # The loaded transcript may be stale (uploaded before the previous + # attempt stored this message), so skipping it would leave the + # transcript without the user turn, creating a malformed + # assistant-after-assistant structure when the LLM reply is added. + if message and is_user_message: + transcript_builder.append_user(content=user_message_for_transcript or message) + + # --- File attachments (feature parity with SDK path) --- + working_dir: str | None = None + attachment_hint = "" + image_blocks: list[dict[str, Any]] = [] + if file_ids and user_id: + working_dir = tempfile.mkdtemp(prefix=f"copilot-baseline-{session_id[:8]}-") + attachment_hint, image_blocks = await _prepare_baseline_attachments( + file_ids, user_id, session_id, working_dir + ) + + # --- URL context --- + context_hint = "" + if context and context.get("url"): + url = context["url"] + content_text = context.get("content", "") + if content_text: + context_hint = ( + f"\n[The user shared a URL: {url}\nContent:\n{content_text[:8000]}]" + ) + else: + context_hint = f"\n[The user shared a URL: {url}]" + + # Append attachment + context hints and image blocks to the last user + # message in a single reverse scan. + extra_hint = attachment_hint + context_hint + if extra_hint or image_blocks: + for i in range(len(openai_messages) - 1, -1, -1): + if openai_messages[i].get("role") == "user": + existing = openai_messages[i].get("content", "") + if isinstance(existing, str): + text = existing + "\n" + extra_hint if extra_hint else existing + if image_blocks: + parts: list[dict[str, Any]] = [{"type": "text", "text": text}] + for img in image_blocks: + parts.append( + { + "type": "image_url", + "image_url": { + "url": ( + f"data:{img['source']['media_type']};" + f"base64,{img['source']['data']}" + ) + }, + } + ) + openai_messages[i]["content"] = parts + else: + openai_messages[i]["content"] = text + break + tools = get_available_tools() + # --- Permission filtering --- + if permissions is not None: + tools = _filter_tools_by_permissions(tools, permissions) + + # Pre-mark cache_control on the last tool schema once per session. The + # tool set is static within a request, so doing this here (instead of in + # _baseline_llm_caller) avoids re-copying ~43 tool dicts on every LLM + # round of the tool-call loop. + # + # Only apply to Anthropic routes — OpenAI/Grok/other providers would + # 400 on the unknown ``cache_control`` field inside tool definitions. + if _is_anthropic_model(active_model): + tools = cast( + list[ChatCompletionToolParam], _mark_tools_with_cache_control(tools) + ) + # Propagate execution context so tool handlers can read session-level flags. - set_execution_context(user_id, session) + set_execution_context( + user_id, + session, + sandbox=e2b_sandbox, + sdk_cwd=working_dir, + permissions=permissions, + ) yield StreamStart(messageId=message_id, sessionId=session_id) @@ -478,13 +1611,37 @@ async def stream_chat_completion_baseline( logger.warning("[Baseline] Langfuse trace context setup failed") _stream_error = False # Track whether an error occurred during streaming - state = _BaselineStreamState() + state = _BaselineStreamState(model=active_model) # Bind extracted module-level callbacks to this request's state/session # using functools.partial so they satisfy the Protocol signatures. _bound_llm_caller = partial(_baseline_llm_caller, state=state) - _bound_tool_executor = partial( - _baseline_tool_executor, state=state, user_id=user_id, session=session + + # ``session`` is reassigned after each mid-turn ``persist_session_safe`` + # call (``upsert_chat_session`` returns a fresh ``model_copy``). Holding + # the object via ``partial(session=session)`` would pin tool executions + # to the *original* object — any post-persist ``session.successful_agent_runs`` + # mutation from a run_agent tool call would then land on the stale copy + # and be lost on the final persist. Wrap in a 1-element holder and read + # the current binding lazily so the executor always sees the latest session. + _session_holder: list[ChatSession] = [session] + + async def _bound_tool_executor( + tool_call: LLMToolCall, tools: Sequence[Any] + ) -> ToolCallResult: + return await _baseline_tool_executor( + tool_call, + tools, + state=state, + user_id=user_id, + session=_session_holder[0], + ) + + _bound_conversation_updater = partial( + _baseline_conversation_updater, + transcript_builder=transcript_builder, + model=active_model, + state=state, ) try: @@ -494,7 +1651,7 @@ async def stream_chat_completion_baseline( tools=tools, llm_call=_bound_llm_caller, execute_tool=_bound_tool_executor, - update_conversation=_baseline_conversation_updater, + update_conversation=_bound_conversation_updater, max_iterations=_MAX_TOOL_ROUNDS, ): # Drain buffered events after each iteration (real-time streaming) @@ -502,6 +1659,124 @@ async def stream_chat_completion_baseline( yield evt state.pending_events.clear() + # Inject any messages the user queued while the turn was + # running. ``tool_call_loop`` mutates ``openai_messages`` + # in-place, so appending here means the model sees the new + # messages on its next LLM call. + # + # IMPORTANT: skip when the loop has already finished (no + # more LLM calls are coming). ``tool_call_loop`` yields + # a final ``ToolCallLoopResult`` on both paths: + # - natural finish: ``finished_naturally=True`` + # - hit max_iterations: ``finished_naturally=False`` + # and ``iterations >= max_iterations`` + # In either case the loop is about to return on the next + # ``async for`` step, so draining here would silently + # lose the message (the user sees 202 but the model never + # reads the text). Those messages stay in the buffer and + # get picked up at the start of the next turn. + is_final_yield = ( + loop_result.finished_naturally + or loop_result.iterations >= _MAX_TOOL_ROUNDS + ) + if is_final_yield: + continue + try: + pending = await drain_pending_messages(session_id) + except Exception: + logger.warning( + "[Baseline] mid-loop drain_pending_messages failed for session %s", + session_id, + exc_info=True, + ) + pending = [] + if pending: + # Flush any buffered assistant/tool messages from completed + # rounds into session.messages BEFORE appending the pending + # user message. ``_baseline_conversation_updater`` only + # records assistant+tool rounds into ``state.session_messages`` + # — they are normally batch-flushed in the finally block. + # Without this in-order flush, the mid-loop pending user + # message lands before the preceding round's assistant/tool + # entries, producing chronologically-wrong session.messages + # on persist (user interposed between an assistant tool_call + # and its tool-result), which breaks OpenAI tool-call ordering + # invariants on the next turn's replay. + # + # Also persist any assistant text from text-only rounds (rounds + # with no tool calls, which ``_baseline_conversation_updater`` + # does NOT record in session_messages). If we only update + # ``_flushed_assistant_text_len`` without persisting the text, + # that text is silently lost: the finally block only appends + # assistant_text[_flushed_assistant_text_len:], so text generated + # before this drain never reaches session.messages. + recorded_text = "".join( + m.content or "" + for m in state.session_messages + if m.role == "assistant" + ) + unflushed_text = state.assistant_text[ + state._flushed_assistant_text_len : + ] + text_only_text = ( + unflushed_text[len(recorded_text) :] + if unflushed_text.startswith(recorded_text) + else unflushed_text + ) + if text_only_text.strip(): + session.messages.append( + ChatMessage(role="assistant", content=text_only_text) + ) + for _buffered in state.session_messages: + session.messages.append(_buffered) + state.session_messages.clear() + # Record how much assistant_text has been covered by the + # structured entries just flushed, so the finally block's + # final-text dedup doesn't re-append rounds already persisted. + state._flushed_assistant_text_len = len(state.assistant_text) + + # Persist the assistant/tool flush BEFORE the pending append + # so a later pending-persist failure can roll back the + # pending rows without also discarding LLM output. + session = await persist_session_safe(session, "[Baseline]") + # ``upsert_chat_session`` may return a *new* ``ChatSession`` + # instance (e.g. when a concurrent title update has written a + # newer title to Redis, it returns ``session.model_copy``). + # Keep ``_session_holder`` in sync so subsequent tool rounds + # executed via ``_bound_tool_executor`` see the fresh session + # — any tool-side mutations on the stale object would be + # discarded when the new one is persisted in the ``finally``. + _session_holder[0] = session + + # ``format_pending_as_user_message`` embeds file attachments + # and context URL/page content into the content string so + # the in-session transcript is a faithful copy of what the + # model actually saw. We also mirror each push into + # ``openai_messages`` so the model's next LLM round sees it. + # + # Pre-compute the formatted dicts once so both the openai + # messages append and the content_of lookup inside the + # shared helper use the same string — and so ``on_rollback`` + # can trim ``openai_messages`` to the recorded anchor. + formatted_by_pm = { + id(pm): format_pending_as_user_message(pm) for pm in pending + } + _openai_anchor = len(openai_messages) + for pm in pending: + openai_messages.append(formatted_by_pm[id(pm)]) + + def _trim_openai_on_rollback(_session_anchor: int) -> None: + del openai_messages[_openai_anchor:] + + await persist_pending_as_user_rows( + session, + transcript_builder, + pending, + log_prefix="[Baseline]", + content_of=lambda pm: formatted_by_pm[id(pm)]["content"], + on_rollback=_trim_openai_on_rollback, + ) + if loop_result and not loop_result.finished_naturally: limit_msg = ( f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds " @@ -517,33 +1792,41 @@ async def stream_chat_completion_baseline( _stream_error = True error_msg = str(e) or type(e).__name__ logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True) - # Close any open text block. The llm_caller's finally block - # already appended StreamFinishStep to pending_events, so we must - # insert StreamTextEnd *before* StreamFinishStep to preserve the - # protocol ordering: - # StreamStartStep -> StreamTextStart -> ...deltas... -> + # ``_baseline_llm_caller``'s finally block closes any open + # reasoning / text blocks and appends ``StreamFinishStep`` on + # both normal and exception paths, so pending_events already has + # the correct protocol ordering: + # StreamStartStep -> StreamReasoningStart -> ...deltas... -> + # StreamReasoningEnd -> StreamTextStart -> ...deltas... -> # StreamTextEnd -> StreamFinishStep - # Appending (or yielding directly) would place it after - # StreamFinishStep, violating the protocol. - if state.text_started: - # Find the last StreamFinishStep and insert before it. - insert_pos = len(state.pending_events) - for i in range(len(state.pending_events) - 1, -1, -1): - if isinstance(state.pending_events[i], StreamFinishStep): - insert_pos = i - break - state.pending_events.insert( - insert_pos, StreamTextEnd(id=state.text_block_id) - ) - # Drain pending events in correct order + # Just drain what's buffered, then yield the error. for evt in state.pending_events: yield evt state.pending_events.clear() yield StreamError(errorText=error_msg, code="baseline_error") # Still persist whatever we got finally: - # Close Langfuse trace context + # Pending messages are drained atomically at turn start and + # between tool rounds, so there's nothing to clear in finally. + # Any message pushed after the final drain window stays in the + # buffer and gets picked up at the start of the next turn. + + # Set cost attributes on OTEL span before closing if _trace_ctx is not None: + try: + span = otel_trace.get_current_span() + if span and span.is_recording(): + span.set_attribute( + "gen_ai.usage.prompt_tokens", state.turn_prompt_tokens + ) + span.set_attribute( + "gen_ai.usage.completion_tokens", + state.turn_completion_tokens, + ) + if state.cost_usd is not None: + span.set_attribute("gen_ai.usage.cost_usd", state.cost_usd) + except Exception: + logger.debug("[Baseline] Failed to set OTEL cost attributes") try: _trace_ctx.__exit__(None, None, None) except Exception: @@ -563,10 +1846,10 @@ async def stream_chat_completion_baseline( and not (_stream_error and not state.assistant_text) ): state.turn_prompt_tokens = max( - estimate_token_count(openai_messages, model=config.model), 1 + estimate_token_count(openai_messages, model=active_model), 1 ) state.turn_completion_tokens = estimate_token_count_str( - state.assistant_text, model=config.model + state.assistant_text, model=active_model ) logger.info( "[Baseline] No streaming usage reported; estimated tokens: " @@ -574,39 +1857,112 @@ async def stream_chat_completion_baseline( state.turn_prompt_tokens, state.turn_completion_tokens, ) - # Persist token usage to session and record for rate limiting. - # NOTE: OpenRouter folds cached tokens into prompt_tokens, so we - # cannot break out cache_read/cache_creation weights. Users on the - # baseline path may be slightly over-counted vs the SDK path. + # When prompt_tokens_details.cached_tokens is reported, subtract + # them from prompt_tokens to get the uncached count so the cost + # breakdown stays accurate. + uncached_prompt = state.turn_prompt_tokens + if state.turn_cache_read_tokens > 0: + uncached_prompt = max( + 0, state.turn_prompt_tokens - state.turn_cache_read_tokens + ) await persist_and_record_usage( session=session, user_id=user_id, - prompt_tokens=state.turn_prompt_tokens, + prompt_tokens=uncached_prompt, completion_tokens=state.turn_completion_tokens, + cache_read_tokens=state.turn_cache_read_tokens, + cache_creation_tokens=state.turn_cache_creation_tokens, log_prefix="[Baseline]", + cost_usd=state.cost_usd, + model=active_model, ) - # Persist assistant response - if state.assistant_text: - session.messages.append( - ChatMessage(role="assistant", content=state.assistant_text) + # Persist structured tool-call history (assistant + tool messages) + # collected by the conversation updater, then the final text response. + for msg in state.session_messages: + session.messages.append(msg) + # Append the final assistant text (from the last LLM call that had + # no tool calls, i.e. the natural finish). Only add it if the + # conversation updater didn't already record it as part of a + # tool-call round (which would have empty response_text). + # Only consider assistant text produced AFTER the last mid-loop + # flush. ``_flushed_assistant_text_len`` tracks the prefix already + # persisted via structured session_messages during mid-loop pending + # drains; including it here would duplicate those rounds. + final_text = state.assistant_text[state._flushed_assistant_text_len :] + if state.session_messages: + # Strip text already captured in tool-call round messages + recorded = "".join( + m.content or "" for m in state.session_messages if m.role == "assistant" ) + if final_text.startswith(recorded): + final_text = final_text[len(recorded) :] + if final_text.strip(): + session.messages.append(ChatMessage(role="assistant", content=final_text)) try: await upsert_chat_session(session) except Exception as persist_err: logger.error("[Baseline] Failed to persist session: %s", persist_err) + # --- Graphiti: ingest conversation turn for temporal memory --- + if graphiti_enabled and user_id and message and is_user_message: + from backend.copilot.graphiti.ingest import enqueue_conversation_turn + + # Pass only the final assistant reply (after stripping tool-loop + # chatter) so derived-finding distillation sees the substantive + # response, not intermediate tool-planning text. + _ingest_task = asyncio.create_task( + enqueue_conversation_turn( + user_id, + session_id, + message, + assistant_msg=final_text if state else "", + ) + ) + _background_tasks.add(_ingest_task) + _ingest_task.add_done_callback(_background_tasks.discard) + + # --- Upload transcript for next-turn continuity --- + # Backfill partial assistant text that wasn't recorded by the + # conversation updater (e.g. when the stream aborted mid-round). + # Without this, mode-switching after a failed turn would lose + # the partial assistant response from the transcript. + if _stream_error and state.assistant_text: + if transcript_builder.last_entry_type != "assistant": + transcript_builder.append_assistant( + content_blocks=[{"type": "text", "text": state.assistant_text}], + model=active_model, + stop_reason=STOP_REASON_END_TURN, + ) + + if user_id and should_upload_transcript(user_id, transcript_upload_safe): + await _upload_final_transcript( + user_id=user_id, + session_id=session_id, + transcript_builder=transcript_builder, + session_msg_count=len(session.messages), + ) + + # Clean up the ephemeral working directory used for file attachments. + if working_dir is not None: + shutil.rmtree(working_dir, ignore_errors=True) + # Yield usage and finish AFTER try/finally (not inside finally). # PEP 525 prohibits yielding from finally in async generators during # aclose() — doing so raises RuntimeError on client disconnect. # On GeneratorExit the client is already gone, so unreachable yields # are harmless; on normal completion they reach the SSE stream. if state.turn_prompt_tokens > 0 or state.turn_completion_tokens > 0: + # Report uncached prompt tokens to match what was billed — cached tokens + # are excluded so the frontend display is consistent with cost_usd. + billed_prompt = max(0, state.turn_prompt_tokens - state.turn_cache_read_tokens) yield StreamUsage( - prompt_tokens=state.turn_prompt_tokens, + prompt_tokens=billed_prompt, completion_tokens=state.turn_completion_tokens, - total_tokens=state.turn_prompt_tokens + state.turn_completion_tokens, + total_tokens=billed_prompt + state.turn_completion_tokens, + cache_read_tokens=state.turn_cache_read_tokens, + cache_creation_tokens=state.turn_cache_creation_tokens, ) yield StreamFinish() diff --git a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py new file mode 100644 index 0000000000..03a9ef99c9 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py @@ -0,0 +1,1950 @@ +"""Unit tests for baseline service pure-logic helpers. + +These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState`` +without requiring API keys, database connections, or network access. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from openai.types.chat import ChatCompletionToolParam + +from backend.copilot.baseline.service import ( + _baseline_conversation_updater, + _baseline_llm_caller, + _BaselineStreamState, + _build_cached_system_message, + _compress_session_messages, + _extract_cache_creation_tokens, + _fresh_anthropic_caching_headers, + _fresh_ephemeral_cache_control, + _is_anthropic_model, + _mark_system_message_with_cache_control, + _mark_tools_with_cache_control, +) +from backend.copilot.model import ChatMessage +from backend.copilot.response_model import ( + StreamReasoningDelta, + StreamReasoningEnd, + StreamReasoningStart, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, +) +from backend.copilot.transcript_builder import TranscriptBuilder +from backend.util.prompt import CompressResult +from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult + + +class TestBaselineStreamState: + def test_defaults(self): + state = _BaselineStreamState() + assert state.pending_events == [] + assert state.assistant_text == "" + assert state.text_started is False + assert state.turn_prompt_tokens == 0 + assert state.turn_completion_tokens == 0 + assert state.text_block_id # Should be a UUID string + + def test_mutable_fields(self): + state = _BaselineStreamState() + state.assistant_text = "hello" + state.turn_prompt_tokens = 100 + state.turn_completion_tokens = 50 + assert state.assistant_text == "hello" + assert state.turn_prompt_tokens == 100 + assert state.turn_completion_tokens == 50 + + +class TestBaselineConversationUpdater: + """Tests for _baseline_conversation_updater which updates the OpenAI + message list and transcript builder after each LLM call.""" + + def _make_transcript_builder(self) -> TranscriptBuilder: + builder = TranscriptBuilder() + builder.append_user("test question") + return builder + + def test_text_only_response(self): + """When the LLM returns text without tool calls, the updater appends + a single assistant message and records it in the transcript.""" + messages: list = [] + builder = self._make_transcript_builder() + response = LLMLoopResponse( + response_text="Hello, world!", + tool_calls=[], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + + _baseline_conversation_updater( + messages, + response, + tool_results=None, + transcript_builder=builder, + model="test-model", + ) + + assert len(messages) == 1 + assert messages[0]["role"] == "assistant" + assert messages[0]["content"] == "Hello, world!" + # Transcript should have user + assistant + assert builder.entry_count == 2 + assert builder.last_entry_type == "assistant" + + def test_tool_calls_response(self): + """When the LLM returns tool calls, the updater appends the assistant + message with tool_calls and tool result messages.""" + messages: list = [] + builder = self._make_transcript_builder() + response = LLMLoopResponse( + response_text="Let me search...", + tool_calls=[ + LLMToolCall( + id="tc_1", + name="search", + arguments='{"query": "test"}', + ), + ], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + tool_results = [ + ToolCallResult( + tool_call_id="tc_1", + tool_name="search", + content="Found result", + ), + ] + + _baseline_conversation_updater( + messages, + response, + tool_results=tool_results, + transcript_builder=builder, + model="test-model", + ) + + # Messages: assistant (with tool_calls) + tool result + assert len(messages) == 2 + assert messages[0]["role"] == "assistant" + assert messages[0]["content"] == "Let me search..." + assert len(messages[0]["tool_calls"]) == 1 + assert messages[0]["tool_calls"][0]["id"] == "tc_1" + assert messages[1]["role"] == "tool" + assert messages[1]["tool_call_id"] == "tc_1" + assert messages[1]["content"] == "Found result" + + # Transcript: user + assistant(tool_use) + user(tool_result) + assert builder.entry_count == 3 + + def test_tool_calls_without_text(self): + """Tool calls without accompanying text should still work.""" + messages: list = [] + builder = self._make_transcript_builder() + response = LLMLoopResponse( + response_text=None, + tool_calls=[ + LLMToolCall(id="tc_1", name="run", arguments="{}"), + ], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + tool_results = [ + ToolCallResult(tool_call_id="tc_1", tool_name="run", content="done"), + ] + + _baseline_conversation_updater( + messages, + response, + tool_results=tool_results, + transcript_builder=builder, + model="test-model", + ) + + assert len(messages) == 2 + assert "content" not in messages[0] # No text content + assert messages[0]["tool_calls"][0]["function"]["name"] == "run" + + def test_no_text_no_tools(self): + """When the response has no text and no tool calls, nothing is appended.""" + messages: list = [] + builder = self._make_transcript_builder() + response = LLMLoopResponse( + response_text=None, + tool_calls=[], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + + _baseline_conversation_updater( + messages, + response, + tool_results=None, + transcript_builder=builder, + model="test-model", + ) + + assert len(messages) == 0 + # Only the user entry from setup + assert builder.entry_count == 1 + + def test_multiple_tool_calls(self): + """Multiple tool calls in a single response are all recorded.""" + messages: list = [] + builder = self._make_transcript_builder() + response = LLMLoopResponse( + response_text=None, + tool_calls=[ + LLMToolCall(id="tc_1", name="tool_a", arguments="{}"), + LLMToolCall(id="tc_2", name="tool_b", arguments='{"x": 1}'), + ], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + tool_results = [ + ToolCallResult(tool_call_id="tc_1", tool_name="tool_a", content="result_a"), + ToolCallResult(tool_call_id="tc_2", tool_name="tool_b", content="result_b"), + ] + + _baseline_conversation_updater( + messages, + response, + tool_results=tool_results, + transcript_builder=builder, + model="test-model", + ) + + # 1 assistant + 2 tool results + assert len(messages) == 3 + assert len(messages[0]["tool_calls"]) == 2 + assert messages[1]["tool_call_id"] == "tc_1" + assert messages[2]["tool_call_id"] == "tc_2" + + def test_invalid_tool_arguments_handled(self): + """Tool call with invalid JSON arguments: the arguments field is + stored as-is in the message, and orjson failure falls back to {} + in the transcript content_blocks.""" + messages: list = [] + builder = self._make_transcript_builder() + response = LLMLoopResponse( + response_text=None, + tool_calls=[ + LLMToolCall(id="tc_1", name="tool_x", arguments="not-json"), + ], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + tool_results = [ + ToolCallResult(tool_call_id="tc_1", tool_name="tool_x", content="ok"), + ] + + _baseline_conversation_updater( + messages, + response, + tool_results=tool_results, + transcript_builder=builder, + model="test-model", + ) + + # Should not raise — invalid JSON falls back to {} in transcript + assert len(messages) == 2 + assert messages[0]["tool_calls"][0]["function"]["arguments"] == "not-json" + + +class TestCompressSessionMessagesPreservesToolCalls: + """``_compress_session_messages`` must round-trip tool_calls + tool_call_id. + + Compression serialises ChatMessage to dict for ``compress_context`` and + reifies the result back to ChatMessage. A regression that drops + ``tool_calls`` or ``tool_call_id`` would corrupt the OpenAI message + list and break downstream tool-execution rounds. + """ + + @pytest.mark.asyncio + async def test_compressed_output_keeps_tool_calls_and_ids(self): + # Simulate compression that returns a summary + the most recent + # assistant(tool_call) + tool(tool_result) intact. + summary = {"role": "system", "content": "prior turns: user asked X"} + assistant_with_tc = { + "role": "assistant", + "content": "calling tool", + "tool_calls": [ + { + "id": "tc_abc", + "type": "function", + "function": {"name": "search", "arguments": '{"q":"y"}'}, + } + ], + } + tool_result = { + "role": "tool", + "tool_call_id": "tc_abc", + "content": "search result", + } + + compress_result = CompressResult( + messages=[summary, assistant_with_tc, tool_result], + token_count=100, + was_compacted=True, + original_token_count=5000, + messages_summarized=10, + messages_dropped=0, + ) + + # Input: messages that should be compressed. + input_messages = [ + ChatMessage(role="user", content="q1"), + ChatMessage( + role="assistant", + content="calling tool", + tool_calls=[ + { + "id": "tc_abc", + "type": "function", + "function": { + "name": "search", + "arguments": '{"q":"y"}', + }, + } + ], + ), + ChatMessage( + role="tool", + tool_call_id="tc_abc", + content="search result", + ), + ] + + with patch( + "backend.copilot.baseline.service.compress_context", + new=AsyncMock(return_value=compress_result), + ): + compressed = await _compress_session_messages( + input_messages, model="openrouter/anthropic/claude-opus-4" + ) + + # Summary, assistant(tool_calls), tool(tool_call_id). + assert len(compressed) == 3 + # Assistant message must keep its tool_calls intact. + assistant_msg = compressed[1] + assert assistant_msg.role == "assistant" + assert assistant_msg.tool_calls is not None + assert len(assistant_msg.tool_calls) == 1 + assert assistant_msg.tool_calls[0]["id"] == "tc_abc" + assert assistant_msg.tool_calls[0]["function"]["name"] == "search" + # Tool-role message must keep tool_call_id for OpenAI linkage. + tool_msg = compressed[2] + assert tool_msg.role == "tool" + assert tool_msg.tool_call_id == "tc_abc" + assert tool_msg.content == "search result" + + @pytest.mark.asyncio + async def test_uncompressed_passthrough_keeps_fields(self): + """When compression is a no-op (was_compacted=False), the original + messages must be returned unchanged — including tool_calls.""" + input_messages = [ + ChatMessage( + role="assistant", + content="c", + tool_calls=[ + { + "id": "t1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + } + ], + ), + ChatMessage(role="tool", tool_call_id="t1", content="ok"), + ] + + noop_result = CompressResult( + messages=[], # ignored when was_compacted=False + token_count=10, + was_compacted=False, + ) + + with patch( + "backend.copilot.baseline.service.compress_context", + new=AsyncMock(return_value=noop_result), + ): + out = await _compress_session_messages( + input_messages, model="openrouter/anthropic/claude-opus-4" + ) + + assert out is input_messages # same list returned + assert out[0].tool_calls is not None + assert out[0].tool_calls[0]["id"] == "t1" + assert out[1].tool_call_id == "t1" + + +# ---- _filter_tools_by_permissions tests ---- # + + +def _make_tool(name: str) -> ChatCompletionToolParam: + """Build a minimal OpenAI ChatCompletionToolParam.""" + return ChatCompletionToolParam( + type="function", + function={"name": name, "parameters": {}}, + ) + + +class TestFilterToolsByPermissions: + """Tests for _filter_tools_by_permissions.""" + + @patch( + "backend.copilot.permissions.all_known_tool_names", + return_value=frozenset({"run_block", "web_fetch", "bash_exec"}), + ) + def test_empty_permissions_returns_all(self, _mock_names): + """Empty permissions (no filtering) returns every tool unchanged.""" + from backend.copilot.baseline.service import _filter_tools_by_permissions + from backend.copilot.permissions import CopilotPermissions + + tools = [_make_tool("run_block"), _make_tool("web_fetch")] + perms = CopilotPermissions() + result = _filter_tools_by_permissions(tools, perms) + assert result == tools + + @patch( + "backend.copilot.permissions.all_known_tool_names", + return_value=frozenset({"run_block", "web_fetch", "bash_exec"}), + ) + def test_allowlist_keeps_only_matching(self, _mock_names): + """Explicit allowlist (tools_exclude=False) keeps only listed tools.""" + from backend.copilot.baseline.service import _filter_tools_by_permissions + from backend.copilot.permissions import CopilotPermissions + + tools = [ + _make_tool("run_block"), + _make_tool("web_fetch"), + _make_tool("bash_exec"), + ] + perms = CopilotPermissions(tools=["web_fetch"], tools_exclude=False) + result = _filter_tools_by_permissions(tools, perms) + assert len(result) == 1 + assert result[0]["function"]["name"] == "web_fetch" + + @patch( + "backend.copilot.permissions.all_known_tool_names", + return_value=frozenset({"run_block", "web_fetch", "bash_exec"}), + ) + def test_blacklist_excludes_listed(self, _mock_names): + """Blacklist (tools_exclude=True) removes only the listed tools.""" + from backend.copilot.baseline.service import _filter_tools_by_permissions + from backend.copilot.permissions import CopilotPermissions + + tools = [ + _make_tool("run_block"), + _make_tool("web_fetch"), + _make_tool("bash_exec"), + ] + perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True) + result = _filter_tools_by_permissions(tools, perms) + names = [t["function"]["name"] for t in result] + assert "bash_exec" not in names + assert "run_block" in names + assert "web_fetch" in names + assert len(result) == 2 + + @patch( + "backend.copilot.permissions.all_known_tool_names", + return_value=frozenset({"run_block", "web_fetch", "bash_exec"}), + ) + def test_unknown_tool_name_filtered_out(self, _mock_names): + """A tool whose name is not in all_known_tool_names is dropped.""" + from backend.copilot.baseline.service import _filter_tools_by_permissions + from backend.copilot.permissions import CopilotPermissions + + tools = [_make_tool("run_block"), _make_tool("unknown_tool")] + perms = CopilotPermissions(tools=["run_block"], tools_exclude=False) + result = _filter_tools_by_permissions(tools, perms) + names = [t["function"]["name"] for t in result] + assert "unknown_tool" not in names + assert names == ["run_block"] + + +# ---- _prepare_baseline_attachments tests ---- # + + +class TestPrepareBaselineAttachments: + """Tests for _prepare_baseline_attachments.""" + + @pytest.mark.asyncio + async def test_empty_file_ids(self): + """Empty file_ids returns empty hint and blocks.""" + from backend.copilot.baseline.service import _prepare_baseline_attachments + + hint, blocks = await _prepare_baseline_attachments([], "user1", "sess1", "/tmp") + assert hint == "" + assert blocks == [] + + @pytest.mark.asyncio + async def test_empty_user_id(self): + """Empty user_id returns empty hint and blocks.""" + from backend.copilot.baseline.service import _prepare_baseline_attachments + + hint, blocks = await _prepare_baseline_attachments( + ["file1"], "", "sess1", "/tmp" + ) + assert hint == "" + assert blocks == [] + + @pytest.mark.asyncio + async def test_image_file_returns_vision_blocks(self): + """A PNG image within size limits is returned as a base64 vision block.""" + from backend.copilot.baseline.service import _prepare_baseline_attachments + + fake_info = AsyncMock() + fake_info.name = "photo.png" + fake_info.mime_type = "image/png" + fake_info.size_bytes = 1024 + + fake_manager = AsyncMock() + fake_manager.get_file_info = AsyncMock(return_value=fake_info) + fake_manager.read_file_by_id = AsyncMock(return_value=b"\x89PNG_FAKE_DATA") + + with patch( + "backend.copilot.baseline.service.get_workspace_manager", + new=AsyncMock(return_value=fake_manager), + ): + hint, blocks = await _prepare_baseline_attachments( + ["fid1"], "user1", "sess1", "/tmp/workdir" + ) + + assert len(blocks) == 1 + assert blocks[0]["type"] == "image" + assert blocks[0]["source"]["media_type"] == "image/png" + assert blocks[0]["source"]["type"] == "base64" + assert "photo.png" in hint + assert "embedded as image" in hint + + @pytest.mark.asyncio + async def test_non_image_file_saved_to_working_dir(self, tmp_path): + """A non-image file is written to working_dir.""" + from backend.copilot.baseline.service import _prepare_baseline_attachments + + fake_info = AsyncMock() + fake_info.name = "data.csv" + fake_info.mime_type = "text/csv" + fake_info.size_bytes = 42 + + fake_manager = AsyncMock() + fake_manager.get_file_info = AsyncMock(return_value=fake_info) + fake_manager.read_file_by_id = AsyncMock(return_value=b"col1,col2\na,b") + + with patch( + "backend.copilot.baseline.service.get_workspace_manager", + new=AsyncMock(return_value=fake_manager), + ): + hint, blocks = await _prepare_baseline_attachments( + ["fid1"], "user1", "sess1", str(tmp_path) + ) + + assert blocks == [] + assert "data.csv" in hint + assert "saved to" in hint + saved = tmp_path / "data.csv" + assert saved.exists() + assert saved.read_bytes() == b"col1,col2\na,b" + + @pytest.mark.asyncio + async def test_file_not_found_skipped(self): + """When get_file_info returns None the file is silently skipped.""" + from backend.copilot.baseline.service import _prepare_baseline_attachments + + fake_manager = AsyncMock() + fake_manager.get_file_info = AsyncMock(return_value=None) + + with patch( + "backend.copilot.baseline.service.get_workspace_manager", + new=AsyncMock(return_value=fake_manager), + ): + hint, blocks = await _prepare_baseline_attachments( + ["missing_id"], "user1", "sess1", "/tmp" + ) + + assert hint == "" + assert blocks == [] + + @pytest.mark.asyncio + async def test_workspace_manager_error(self): + """When get_workspace_manager raises, returns empty results.""" + from backend.copilot.baseline.service import _prepare_baseline_attachments + + with patch( + "backend.copilot.baseline.service.get_workspace_manager", + new=AsyncMock(side_effect=RuntimeError("connection failed")), + ): + hint, blocks = await _prepare_baseline_attachments( + ["fid1"], "user1", "sess1", "/tmp" + ) + + assert hint == "" + assert blocks == [] + + +_COST_MISSING = object() + + +def _make_usage_chunk( + *, + prompt_tokens: int = 0, + completion_tokens: int = 0, + cost: float | str | None | object = _COST_MISSING, + cached_tokens: int | None = None, + cache_creation_input_tokens: int | None = None, +): + """Build a mock streaming chunk carrying usage (and optionally cost). + + Provider-specific fields (``cost`` on usage, ``cache_creation_input_tokens`` + on prompt_tokens_details) are set on ``model_extra`` because that's where + the baseline helper reads them from (typed ``CompletionUsage.model_extra`` + rather than ``getattr``). Pass ``cost=None`` to emit an explicit-null cost + key; omit ``cost`` entirely to leave the key absent. + """ + chunk = MagicMock() + chunk.choices = [] + chunk.usage = MagicMock() + chunk.usage.prompt_tokens = prompt_tokens + chunk.usage.completion_tokens = completion_tokens + usage_extras: dict[str, float | str | None] = {} + if cost is not _COST_MISSING: + usage_extras["cost"] = cost # type: ignore[assignment] + chunk.usage.model_extra = usage_extras + + if cached_tokens is not None or cache_creation_input_tokens is not None: + # Build a real ``PromptTokensDetails`` so ``getattr(ptd, + # "cache_write_tokens", None)`` returns ``None`` on this SDK version + # (rather than a truthy MagicMock attribute) and the extraction + # helper's typed-attr vs model_extra fallback resolves correctly. + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": cached_tokens or 0}) + if cache_creation_input_tokens is not None: + if ptd.model_extra is None: + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_creation_input_tokens"] = cache_creation_input_tokens + chunk.usage.prompt_tokens_details = ptd + else: + chunk.usage.prompt_tokens_details = None + + return chunk + + +def _make_stream_mock(*chunks): + """Build an async streaming response mock that yields *chunks* in order.""" + stream = MagicMock() + stream.close = AsyncMock() + + async def aiter(): + for c in chunks: + yield c + + stream.__aiter__ = lambda self: aiter() + return stream + + +class TestBaselineCostExtraction: + """Tests for ``usage.cost`` extraction in ``_baseline_llm_caller``. + + Cost is read from the OpenRouter ``usage.cost`` field on the final + streaming chunk when the request body includes ``usage: {include: true}`` + (handled by the baseline service via ``extra_body``). + """ + + @pytest.mark.asyncio + async def test_cost_usd_extracted_from_usage_chunk(self): + """state.cost_usd is set from chunk.usage.cost when present.""" + state = _BaselineStreamState(model="gpt-4o-mini") + chunk = _make_usage_chunk( + prompt_tokens=1000, completion_tokens=200, cost=0.0123 + ) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd == pytest.approx(0.0123) + + @pytest.mark.asyncio + async def test_cost_usd_accumulates_across_calls(self): + """cost_usd accumulates when _baseline_llm_caller is called multiple times.""" + state = _BaselineStreamState(model="gpt-4o-mini") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + side_effect=[ + _make_stream_mock(_make_usage_chunk(prompt_tokens=500, cost=0.01)), + _make_stream_mock(_make_usage_chunk(prompt_tokens=600, cost=0.02)), + ] + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "first"}], + tools=[], + state=state, + ) + await _baseline_llm_caller( + messages=[{"role": "user", "content": "second"}], + tools=[], + state=state, + ) + + assert state.cost_usd == pytest.approx(0.03) + + @pytest.mark.asyncio + async def test_cost_usd_accepts_string_value(self): + """OpenRouter may emit cost as a string — it should still parse.""" + state = _BaselineStreamState(model="gpt-4o-mini") + chunk = _make_usage_chunk(prompt_tokens=10, cost="0.005") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd == pytest.approx(0.005) + + @pytest.mark.asyncio + async def test_cost_usd_none_when_usage_cost_missing(self): + """state.cost_usd stays None when the usage chunk lacks a cost field.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4") + chunk = _make_usage_chunk(prompt_tokens=1000, completion_tokens=500) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + # Token accumulators are still populated so the caller can log them. + assert state.turn_prompt_tokens == 1000 + assert state.turn_completion_tokens == 500 + + @pytest.mark.asyncio + async def test_invalid_cost_string_leaves_cost_none(self): + """A non-numeric cost value is rejected without raising.""" + state = _BaselineStreamState(model="gpt-4o-mini") + chunk = _make_usage_chunk(prompt_tokens=10, cost="not-a-number") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + + @pytest.mark.asyncio + async def test_negative_cost_is_ignored(self): + """Guard against negative cost values (shouldn't happen but be safe).""" + state = _BaselineStreamState(model="gpt-4o-mini") + chunk = _make_usage_chunk(prompt_tokens=10, cost=-0.01) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + + @pytest.mark.asyncio + async def test_explicit_null_cost_is_logged_and_ignored(self, caplog): + """`{"cost": null}` is rejected and logged (not silently dropped).""" + state = _BaselineStreamState(model="openrouter/auto") + chunk = _make_usage_chunk(prompt_tokens=10, cost=None) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with ( + patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ), + caplog.at_level("ERROR", logger="backend.copilot.baseline.service"), + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + assert any( + "usage.cost is present but null" in rec.message for rec in caplog.records + ) + + @pytest.mark.asyncio + async def test_cost_not_captured_when_stream_raises_mid_chunk(self): + """If the stream aborts before emitting the usage chunk there is no cost.""" + state = _BaselineStreamState(model="gpt-4o-mini") + + stream = MagicMock() + stream.close = AsyncMock() + + async def failing_aiter(): + raise RuntimeError("stream error") + yield # make it an async generator + + stream.__aiter__ = lambda self: failing_aiter() + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=stream) + + with ( + patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ), + pytest.raises(RuntimeError, match="stream error"), + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + # Stream aborted before yielding the usage chunk — cost stays None. + assert state.cost_usd is None + + @pytest.mark.asyncio + async def test_no_cost_when_api_call_raises_before_stream(self): + """The helper is safe when the create() call itself raises.""" + state = _BaselineStreamState(model="gpt-4o-mini") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + side_effect=RuntimeError("connection refused") + ) + + with ( + patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ), + pytest.raises(RuntimeError, match="connection refused"), + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + + @pytest.mark.asyncio + async def test_cache_tokens_extracted_from_usage_details(self): + """cache tokens are extracted from prompt_tokens_details.cached_tokens.""" + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk( + prompt_tokens=1000, + completion_tokens=200, + cost=0.01, + cached_tokens=800, + ) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.turn_cache_read_tokens == 800 + assert state.turn_prompt_tokens == 1000 + + @pytest.mark.asyncio + async def test_cache_creation_tokens_extracted_from_usage_details(self): + """cache_creation_input_tokens is extracted from prompt_tokens_details.""" + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk( + prompt_tokens=1000, + completion_tokens=200, + cost=0.01, + cached_tokens=0, + cache_creation_input_tokens=500, + ) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.turn_cache_creation_tokens == 500 + + @pytest.mark.asyncio + async def test_token_accumulators_track_across_multiple_calls(self): + """Token accumulators grow correctly across multiple _baseline_llm_caller calls.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + side_effect=[ + _make_stream_mock( + _make_usage_chunk(prompt_tokens=1000, completion_tokens=200) + ), + _make_stream_mock( + _make_usage_chunk(prompt_tokens=1100, completion_tokens=300) + ), + ] + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + await _baseline_llm_caller( + messages=[{"role": "user", "content": "follow up"}], + tools=[], + state=state, + ) + + # No usage.cost on either chunk → cost stays None, tokens still accumulate. + assert state.cost_usd is None + assert state.turn_prompt_tokens == 2100 + assert state.turn_completion_tokens == 500 + + @pytest.mark.parametrize( + "tools", + [ + pytest.param([], id="no_tools"), + pytest.param([_make_tool("search")], id="with_tools"), + ], + ) + @pytest.mark.asyncio + async def test_baseline_requests_usage_include_extra_body( + self, tools: list[ChatCompletionToolParam] + ): + """The baseline call must pass extra_body={'usage': {'include': True}}. + + This guards the contract with OpenRouter that triggers inclusion of + the authoritative cost on the final usage chunk. Without it the + rate-limit counter stays at zero. Exercise both the no-tools and + tool-calling branches so a regression in either path trips the test. + """ + state = _BaselineStreamState(model="gpt-4o-mini") + create_mock = AsyncMock(return_value=_make_stream_mock()) + mock_client = MagicMock() + mock_client.chat.completions.create = create_mock + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=tools, + state=state, + ) + + create_mock.assert_awaited_once() + await_args = create_mock.await_args + assert await_args is not None + assert await_args.kwargs["extra_body"] == {"usage": {"include": True}} + assert await_args.kwargs["stream_options"] == {"include_usage": True} + + +class TestMidLoopPendingFlushOrdering: + """Regression test for the mid-loop pending drain ordering invariant. + + ``_baseline_conversation_updater`` records assistant+tool entries from + each tool-call round into ``state.session_messages``; the finally block + of ``stream_chat_completion_baseline`` batch-flushes them into + ``session.messages`` at the end of the turn. + + The mid-loop pending drain appends pending user messages directly to + ``session.messages``. Without flushing ``state.session_messages`` first, + the pending user message lands BEFORE the preceding round's assistant+ + tool entries in the final persisted ``session.messages`` — which + produces a malformed tool-call/tool-result ordering on the next turn's + replay. + + This test documents the invariant by replaying the production flush + sequence against an in-memory state. + """ + + def test_flush_then_append_preserves_chronological_order(self): + """Mid-loop drain must flush state.session_messages before appending + the pending user message, so the final order matches the + chronological execution order. + """ + # Initial state: user turn already appended by maybe_append_user_message + session_messages: list[ChatMessage] = [ + ChatMessage(role="user", content="original user turn"), + ] + state = _BaselineStreamState() + + # Round 1 completes: conversation_updater buffers assistant+tool + # entries into state.session_messages (but does NOT write to + # session.messages yet). + builder = TranscriptBuilder() + builder.append_user("original user turn") + response = LLMLoopResponse( + response_text="calling search", + tool_calls=[LLMToolCall(id="tc_1", name="search", arguments="{}")], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + tool_results = [ + ToolCallResult( + tool_call_id="tc_1", tool_name="search", content="search output" + ), + ] + openai_messages: list = [] + _baseline_conversation_updater( + openai_messages, + response, + tool_results=tool_results, + transcript_builder=builder, + state=state, + model="test-model", + ) + # state.session_messages should now hold the round-1 assistant + tool + assert len(state.session_messages) == 2 + assert state.session_messages[0].role == "assistant" + assert state.session_messages[1].role == "tool" + + # --- Mid-loop pending drain (production code pattern) --- + # Flush first, THEN append pending. This is the ordering fix. + for _buffered in state.session_messages: + session_messages.append(_buffered) + state.session_messages.clear() + session_messages.append( + ChatMessage(role="user", content="pending mid-loop message") + ) + + # Round 2 completes: new assistant+tool entries buffer again. + response2 = LLMLoopResponse( + response_text="another call", + tool_calls=[LLMToolCall(id="tc_2", name="calc", arguments="{}")], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + tool_results2 = [ + ToolCallResult( + tool_call_id="tc_2", tool_name="calc", content="calc output" + ), + ] + _baseline_conversation_updater( + openai_messages, + response2, + tool_results=tool_results2, + transcript_builder=builder, + state=state, + model="test-model", + ) + + # --- Finally-block flush (end of turn) --- + for msg in state.session_messages: + session_messages.append(msg) + + # Assert chronological order: original user, round-1 assistant, + # round-1 tool, pending user, round-2 assistant, round-2 tool. + assert [m.role for m in session_messages] == [ + "user", + "assistant", + "tool", + "user", + "assistant", + "tool", + ] + assert session_messages[0].content == "original user turn" + assert session_messages[3].content == "pending mid-loop message" + # The assistant message carrying tool_call tc_1 must be immediately + # followed by its tool result — no user message interposed. + assert session_messages[1].role == "assistant" + assert session_messages[1].tool_calls is not None + assert session_messages[1].tool_calls[0]["id"] == "tc_1" + assert session_messages[2].role == "tool" + assert session_messages[2].tool_call_id == "tc_1" + # Same invariant for the round after the pending user. + assert session_messages[4].tool_calls is not None + assert session_messages[4].tool_calls[0]["id"] == "tc_2" + assert session_messages[5].tool_call_id == "tc_2" + + def test_flushed_assistant_text_len_prevents_duplicate_final_text(self): + """After mid-loop drain clears state.session_messages, the finally + block must not re-append assistant text from rounds already flushed. + + ``state.assistant_text`` accumulates ALL rounds' text, but + ``state.session_messages`` only holds entries from rounds AFTER the + last mid-loop flush. Without ``_flushed_assistant_text_len``, the + ``finally`` block's ``startswith(recorded)`` check fails because + ``recorded`` only covers post-flush rounds, and the full + ``assistant_text`` is appended — duplicating pre-flush rounds. + """ + state = _BaselineStreamState() + session_messages: list[ChatMessage] = [ + ChatMessage(role="user", content="user turn"), + ] + + # Simulate round 1 text accumulation (as _bound_llm_caller does) + state.assistant_text += "calling search" + + # Round 1 conversation_updater buffers structured entries + builder = TranscriptBuilder() + builder.append_user("user turn") + response1 = LLMLoopResponse( + response_text="calling search", + tool_calls=[LLMToolCall(id="tc_1", name="search", arguments="{}")], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + _baseline_conversation_updater( + [], + response1, + tool_results=[ + ToolCallResult( + tool_call_id="tc_1", tool_name="search", content="result" + ) + ], + transcript_builder=builder, + state=state, + model="test-model", + ) + + # Mid-loop drain: flush + clear + record flushed text length + for _buffered in state.session_messages: + session_messages.append(_buffered) + state.session_messages.clear() + state._flushed_assistant_text_len = len(state.assistant_text) + session_messages.append(ChatMessage(role="user", content="pending message")) + + # Simulate round 2 text accumulation + state.assistant_text += "final answer" + + # Round 2: natural finish (no tool calls → no session_messages entry) + + # --- Finally block logic (production code) --- + for msg in state.session_messages: + session_messages.append(msg) + + final_text = state.assistant_text[state._flushed_assistant_text_len :] + if state.session_messages: + recorded = "".join( + m.content or "" for m in state.session_messages if m.role == "assistant" + ) + if final_text.startswith(recorded): + final_text = final_text[len(recorded) :] + if final_text.strip(): + session_messages.append(ChatMessage(role="assistant", content=final_text)) + + # The final assistant message should only contain round-2 text, + # not the round-1 text that was already flushed mid-loop. + assistant_msgs = [m for m in session_messages if m.role == "assistant"] + # Round-1 structured assistant (from mid-loop flush) + assert assistant_msgs[0].content == "calling search" + assert assistant_msgs[0].tool_calls is not None + # Round-2 final text (from finally block) + assert assistant_msgs[1].content == "final answer" + assert assistant_msgs[1].tool_calls is None + # Crucially: only 2 assistant messages, not 3 (no duplicate) + assert len(assistant_msgs) == 2 + + +class TestBuilderContextSplit: + """Cross-helper composition: the guide must land in the system prompt via + ``build_builder_system_prompt_suffix`` and NOT in the per-turn user prefix + via ``build_builder_context_turn_prefix``. + + The baseline service composes these two blocks on each turn, so a drift + here (guide leaking into both, or missing from both) would kill Claude's + prompt-cache hit rate for builder sessions. + """ + + @pytest.mark.asyncio + async def test_guide_lives_in_system_prompt_not_user_message(self): + from backend.copilot.builder_context import ( + BUILDER_CONTEXT_TAG, + BUILDER_SESSION_TAG, + build_builder_context_turn_prefix, + build_builder_system_prompt_suffix, + ) + from backend.copilot.model import ChatSession + + session = MagicMock(spec=ChatSession) + session.session_id = "s" + session.metadata = MagicMock() + session.metadata.builder_graph_id = "graph-1" + + agent_json = { + "id": "graph-1", + "name": "Demo", + "version": 7, + "nodes": [ + { + "id": "n1", + "block_id": "block-A", + "input_default": {"name": "Input"}, + "metadata": {}, + } + ], + "links": [], + } + guide_body = "# UNIQUE_GUIDE_MARKER body" + with ( + patch( + "backend.copilot.builder_context.get_agent_as_json", + new=AsyncMock(return_value=agent_json), + ), + patch( + "backend.copilot.builder_context._load_guide", + return_value=guide_body, + ), + ): + suffix = await build_builder_system_prompt_suffix(session) + prefix = await build_builder_context_turn_prefix(session, "user-1") + + # System prompt suffix carries and the guide. + assert f"<{BUILDER_SESSION_TAG}>" in suffix + assert guide_body in suffix + # Dynamic bits must NOT be in the suffix — otherwise renames and + # cross-graph sessions invalidate Claude's prompt cache. + assert "graph-1" not in suffix + assert "Demo" not in suffix + + # Per-turn prefix carries with the full live + # snapshot (id, name, version, nodes) but NEVER the guide. + assert f"<{BUILDER_CONTEXT_TAG}>" in prefix + assert 'id="graph-1"' in prefix + assert 'name="Demo"' in prefix + assert 'version="7"' in prefix + assert guide_body not in prefix + assert "" not in prefix + + # Guide appears in the combined on-the-wire payload exactly ONCE. + combined = suffix + "\n\n" + prefix + assert combined.count(guide_body) == 1 + + +class TestApplyPromptCacheMarkers: + """Tests for _apply_prompt_cache_markers — Anthropic ephemeral + cache_control markers on baseline OpenRouter requests.""" + + def test_system_message_converted_to_content_blocks(self): + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ] + + cached_messages = _mark_system_message_with_cache_control(messages) + + assert cached_messages[0]["role"] == "system" + assert cached_messages[0]["content"] == [ + { + "type": "text", + "text": "You are helpful.", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + # User message must be untouched. + assert cached_messages[1] == {"role": "user", "content": "hello"} + + def test_system_message_preserves_unknown_fields(self): + # Future-proofing: a system message with extra keys (e.g. "name") must + # keep them after the content-blocks conversion. + messages = [ + {"role": "system", "content": "sys", "name": "developer"}, + ] + + cached_messages = _mark_system_message_with_cache_control(messages) + + assert cached_messages[0]["name"] == "developer" + assert cached_messages[0]["role"] == "system" + + def test_last_tool_gets_cache_control(self): + tools = [ + {"type": "function", "function": {"name": "a"}}, + {"type": "function", "function": {"name": "b"}}, + ] + + cached_tools = _mark_tools_with_cache_control(tools) + + assert "cache_control" not in cached_tools[0] + assert cached_tools[-1]["cache_control"] == { + "type": "ephemeral", + "ttl": "1h", + } + # Last tool's other fields preserved. + assert cached_tools[-1]["function"] == {"name": "b"} + + def test_does_not_mutate_input(self): + messages = [{"role": "system", "content": "sys"}] + tools = [{"type": "function", "function": {"name": "a"}}] + + _mark_system_message_with_cache_control(messages) + _mark_tools_with_cache_control(tools) + + assert messages == [{"role": "system", "content": "sys"}] + assert tools == [{"type": "function", "function": {"name": "a"}}] + + def test_no_system_message_safe(self): + messages = [{"role": "user", "content": "hi"}] + cached_messages = _mark_system_message_with_cache_control(messages) + assert cached_messages == messages + + def test_empty_tools_safe(self): + assert _mark_tools_with_cache_control([]) == [] + + def test_non_string_system_content_left_untouched(self): + # If the content is already a list of blocks (e.g. caller pre-marked), + # the helper must not overwrite it. + pre_marked = [ + { + "type": "text", + "text": "sys", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + messages = [{"role": "system", "content": pre_marked}] + cached_messages = _mark_system_message_with_cache_control(messages) + assert cached_messages[0]["content"] == pre_marked + + def test_is_anthropic_model_matches_claude_and_anthropic_prefix(self): + assert _is_anthropic_model("anthropic/claude-sonnet-4-6") + assert _is_anthropic_model("claude-3-5-sonnet-20241022") + assert _is_anthropic_model("anthropic.claude-3-5-sonnet-20241022-v2:0") + assert _is_anthropic_model("ANTHROPIC/Claude-Opus") # case insensitive + + def test_is_anthropic_model_rejects_other_providers(self): + assert not _is_anthropic_model("openai/gpt-4o") + assert not _is_anthropic_model("openai/gpt-5") + assert not _is_anthropic_model("google/gemini-2.5-pro") + assert not _is_anthropic_model("xai/grok-4") + assert not _is_anthropic_model("meta-llama/llama-3.3-70b-instruct") + + def test_cache_control_uses_configured_ttl(self, monkeypatch): + """TTL comes from ChatConfig.baseline_prompt_cache_ttl — defaults + to 1h so the static prefix (system + tools) stays warm across + workspace users past the 5-min default window.""" + from backend.copilot.baseline import service as bsvc + + assert bsvc.config.baseline_prompt_cache_ttl == "1h" + cc = bsvc._fresh_ephemeral_cache_control() + assert cc == {"type": "ephemeral", "ttl": "1h"} + monkeypatch.setattr(bsvc.config, "baseline_prompt_cache_ttl", "5m") + assert bsvc._fresh_ephemeral_cache_control() == { + "type": "ephemeral", + "ttl": "5m", + } + + def test_fresh_helpers_return_distinct_objects(self): + """Regression guard: the `_fresh_*` helpers must return a NEW dict + on every call. A future refactor returning a module-level constant + would silently reintroduce the shared-mutable-state bug flagged + during earlier review cycles.""" + assert _fresh_ephemeral_cache_control() is not _fresh_ephemeral_cache_control() + assert ( + _fresh_anthropic_caching_headers() is not _fresh_anthropic_caching_headers() + ) + + def test_extract_cache_creation_tokens_openrouter_typed_attr(self): + """Newer ``openai-python`` declares ``cache_write_tokens`` as a + typed attribute on ``PromptTokensDetails`` — it no longer lands in + ``model_extra``. Verified empirically against the production + openai==1.113 installed in this venv: OpenRouter streaming + response populates ``ptd.cache_write_tokens`` directly while + ``ptd.model_extra`` is ``{}``. + """ + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate( + { + "audio_tokens": 0, + "cached_tokens": 0, + "cache_write_tokens": 4432, + "video_tokens": 0, + } + ) + assert getattr(ptd, "cache_write_tokens", None) == 4432 + assert _extract_cache_creation_tokens(ptd) == 4432 + + def test_extract_cache_creation_tokens_openrouter_model_extra(self): + """Older SDKs that don't yet declare ``cache_write_tokens`` as a + typed field leave it in ``model_extra`` — the helper must still + find it there.""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + # Force the value into model_extra (simulates the old SDK shape + # where the field wasn't typed yet). + if ptd.model_extra is None: + # Pydantic v2 sometimes exposes __pydantic_extra__ as None when + # extras are disabled; initialise to a dict to mutate safely. + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_write_tokens"] = 7777 + assert _extract_cache_creation_tokens(ptd) == 7777 + + def test_extract_cache_creation_tokens_anthropic_native_field(self): + """Direct Anthropic API uses ``cache_creation_input_tokens`` — + falls through as the final path when neither + ``cache_write_tokens`` typed attr nor model_extra entry exists.""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + if ptd.model_extra is None: + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_creation_input_tokens"] = 2048 + assert _extract_cache_creation_tokens(ptd) == 2048 + + def test_extract_cache_creation_tokens_absent(self): + """Neither provider field present → 0 (non-Anthropic routes or + cache-miss responses).""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + assert _extract_cache_creation_tokens(ptd) == 0 + + def test_build_cached_system_message_applies_cache_control(self): + """The single-message helper wraps the string content in a text block + with an ephemeral cache_control marker.""" + out = _build_cached_system_message({"role": "system", "content": "hi"}) + assert out["role"] == "system" + assert out["content"] == [ + { + "type": "text", + "text": "hi", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + + def test_build_cached_system_message_preserves_extra_fields(self): + """Unknown keys (e.g. ``name``) survive the transformation.""" + out = _build_cached_system_message( + {"role": "system", "content": "sys", "name": "dev"} + ) + assert out["name"] == "dev" + assert out["role"] == "system" + + def test_build_cached_system_message_non_string_passthrough(self): + """Pre-marked list content is returned as-is (shallow-copied).""" + pre_marked = [ + { + "type": "text", + "text": "sys", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + out = _build_cached_system_message({"role": "system", "content": pre_marked}) + assert out["content"] is pre_marked + + @pytest.mark.asyncio + async def test_baseline_llm_caller_memoises_cached_system_message(self): + """The cached system dict is built once and reused across rounds. + + Guards against the perf regression where the entire (growing) + ``messages`` list was copied on every tool-call iteration just to + mark the static system prompt. + """ + state = _BaselineStreamState(model="anthropic/claude-sonnet-4") + chunk = _make_usage_chunk(prompt_tokens=10, completion_tokens=5) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + side_effect=[_make_stream_mock(chunk), _make_stream_mock(chunk)] + ) + + messages: list[dict] = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hi"}, + ] + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller(messages=messages, tools=[], state=state) + first_cached = state.cached_system_message + assert first_cached is not None + # Simulate the tool-call loop growing ``messages`` between rounds. + messages.append({"role": "assistant", "content": "ok"}) + messages.append({"role": "user", "content": "follow up"}) + await _baseline_llm_caller(messages=messages, tools=[], state=state) + + # Same dict instance reused — not rebuilt per round. + assert state.cached_system_message is first_cached + + # Second call's first message is the memoised system dict (not a new copy). + second_call_messages = mock_client.chat.completions.create.call_args_list[1][1][ + "messages" + ] + assert second_call_messages[0] is first_cached + # And the tail messages were spliced in, not re-copied. + assert second_call_messages[1] is messages[1] + assert second_call_messages[-1] is messages[-1] + + @pytest.mark.asyncio + async def test_baseline_llm_caller_skips_memoisation_for_non_anthropic(self): + """Non-Anthropic routes pass messages through unmodified — no cache + dict is built, no list splicing happens.""" + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk(prompt_tokens=10, completion_tokens=5) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + messages: list[dict] = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + ] + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller(messages=messages, tools=[], state=state) + + assert state.cached_system_message is None + # The exact same list object reaches the provider (no copy needed). + call_messages = mock_client.chat.completions.create.call_args[1]["messages"] + assert call_messages is messages + + +def _make_delta_chunk( + *, + content: str | None = None, + reasoning: str | None = None, + reasoning_details: list | None = None, + reasoning_content: str | None = None, + tool_calls: list | None = None, +): + """Build a streaming chunk with a configurable ``delta`` payload. + + The ``delta`` is a real ``ChoiceDelta`` pydantic instance so OpenRouter + extension fields land on ``delta.model_extra`` — which is how + :class:`OpenRouterDeltaExtension` reads them in production. Using a + raw ``MagicMock`` here would leave ``model_extra`` unset and silently + skip the reasoning parser. ``tool_calls`` (when provided) must be + ``MagicMock`` entries compatible with the service's streaming loop; + they're set on the delta via ``object.__setattr__`` because pydantic + would otherwise reject the non-schema types. + """ + from openai.types.chat.chat_completion_chunk import ChoiceDelta + + payload: dict = {"role": "assistant"} + if content is not None: + payload["content"] = content + if reasoning is not None: + payload["reasoning"] = reasoning + if reasoning_content is not None: + payload["reasoning_content"] = reasoning_content + if reasoning_details is not None: + payload["reasoning_details"] = reasoning_details + delta = ChoiceDelta.model_validate(payload) + # ChoiceDelta's tool_calls schema expects OpenAI-typed entries; bypass + # validation so tests can use MagicMocks that mimic the streaming shape. + if tool_calls is not None: + object.__setattr__(delta, "tool_calls", tool_calls) + + chunk = MagicMock() + chunk.usage = None + choice = MagicMock() + choice.delta = delta + chunk.choices = [choice] + return chunk + + +def _make_tool_call_delta(*, index: int, call_id: str, name: str, arguments: str): + """Build a ``delta.tool_calls[i]`` entry for streaming tool-use.""" + tc = MagicMock() + tc.index = index + tc.id = call_id + function = MagicMock() + function.name = name + function.arguments = arguments + tc.function = function + return tc + + +class TestBaselineReasoningStreaming: + """End-to-end reasoning event emission through ``_baseline_llm_caller``.""" + + @pytest.mark.asyncio + async def test_reasoning_then_text_emits_paired_events(self): + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + chunks = [ + _make_delta_chunk(reasoning="thinking..."), + _make_delta_chunk(reasoning=" more"), + _make_delta_chunk(content="final answer"), + ] + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(*chunks) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + types = [type(e).__name__ for e in state.pending_events] + assert "StreamReasoningStart" in types + assert "StreamReasoningDelta" in types + assert "StreamReasoningEnd" in types + + # Reasoning must close before text opens — AI SDK v5 rejects + # interleaved reasoning / text parts. + reason_end = types.index("StreamReasoningEnd") + text_start = types.index("StreamTextStart") + assert reason_end < text_start + + # All reasoning deltas share a single block id; the text block uses + # a fresh id after the reasoning-end rotation. + reasoning_ids = { + e.id + for e in state.pending_events + if isinstance( + e, (StreamReasoningStart, StreamReasoningDelta, StreamReasoningEnd) + ) + } + text_ids = { + e.id + for e in state.pending_events + if isinstance(e, (StreamTextStart, StreamTextDelta, StreamTextEnd)) + } + assert len(reasoning_ids) == 1 + assert len(text_ids) == 1 + assert reasoning_ids.isdisjoint(text_ids) + + combined = "".join( + e.delta for e in state.pending_events if isinstance(e, StreamReasoningDelta) + ) + assert combined == "thinking... more" + + @pytest.mark.asyncio + async def test_reasoning_then_tool_call_closes_reasoning_first(self): + """A tool_call arriving mid-reasoning must close the reasoning block + before the tool-use is flushed — AI SDK v5 treats reasoning and + tool-use as distinct UI parts and rejects interleaving.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + chunks = [ + _make_delta_chunk(reasoning="deliberating..."), + _make_delta_chunk( + tool_calls=[ + _make_tool_call_delta( + index=0, + call_id="call_1", + name="search", + arguments='{"q":"x"}', + ) + ], + ), + ] + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(*chunks) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + response = await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + # A reasoning-end must have been emitted — this is the tool_calls + # branch's responsibility, not the stream-end cleanup. + types = [type(e).__name__ for e in state.pending_events] + assert "StreamReasoningStart" in types + assert "StreamReasoningEnd" in types + + # The tool_call was collected — confirms the tool-use path executed + # after reasoning closed (rather than silently dropping the tool). + assert len(response.tool_calls) == 1 + assert response.tool_calls[0].name == "search" + + # No text events — this stream had no content deltas. + assert "StreamTextStart" not in types + + @pytest.mark.asyncio + async def test_reasoning_closed_on_mid_stream_exception(self): + """Regression guard: an exception during the streaming loop must + still emit ``StreamReasoningEnd`` (and ``StreamTextEnd`` when a + text block is open) before ``StreamFinishStep`` — the frontend + collapse relies on matched start/end pairs, and the outer handler + no longer patches these after-the-fact.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + async def failing_stream(): + yield _make_delta_chunk(reasoning="thinking...") + raise RuntimeError("boom") + + stream = MagicMock() + stream.close = AsyncMock() + stream.__aiter__ = lambda self: failing_stream() + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=stream) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + with pytest.raises(RuntimeError): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + types = [type(e).__name__ for e in state.pending_events] + # The reasoning block was opened, the exception fired, and the + # finally block must have closed it before emitting the finish + # step. + assert "StreamReasoningStart" in types + assert "StreamReasoningEnd" in types + assert "StreamFinishStep" in types + assert types.index("StreamReasoningEnd") < types.index("StreamFinishStep") + # Emitter is reset so a retried round starts with fresh ids. + assert state.reasoning_emitter.is_open is False + + @pytest.mark.asyncio + async def test_reasoning_param_sent_on_anthropic_routes(self): + """Anthropic route gets ``reasoning.max_tokens`` on the request.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock() + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + extra_body = mock_client.chat.completions.create.call_args[1]["extra_body"] + assert "reasoning" in extra_body + assert extra_body["reasoning"]["max_tokens"] > 0 + + @pytest.mark.asyncio + async def test_reasoning_param_absent_on_non_anthropic_routes(self): + """Non-Anthropic routes (e.g. OpenAI) must not receive ``reasoning``.""" + state = _BaselineStreamState(model="openai/gpt-4o") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock() + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + extra_body = mock_client.chat.completions.create.call_args[1]["extra_body"] + assert "reasoning" not in extra_body + + @pytest.mark.asyncio + async def test_reasoning_only_stream_still_closes_block(self): + """Regression: a stream with only reasoning (no text, no tool_call) + must still emit a matching ``reasoning-end`` at stream close so the + frontend Reasoning collapse finalises. Exercised here against + ``_baseline_llm_caller`` to cover the emitter's integration with + the finally-block, not just the unit emitter in reasoning_test.py. + """ + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock( + _make_delta_chunk(reasoning="just thinking"), + ) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + types = [type(e).__name__ for e in state.pending_events] + assert "StreamReasoningStart" in types + assert "StreamReasoningEnd" in types + # No text was produced — no text events should be emitted. + assert "StreamTextStart" not in types + assert "StreamTextDelta" not in types + + @pytest.mark.asyncio + async def test_reasoning_param_suppressed_when_thinking_tokens_zero(self): + """Operator kill switch: setting ``claude_agent_max_thinking_tokens`` + to 0 removes the ``reasoning`` fragment from ``extra_body`` even on + an Anthropic route. Restores the zero-disables behaviour the old + ``baseline_reasoning_max_tokens`` config used to provide.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock() + ) + + with ( + patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ), + patch( + "backend.copilot.baseline.service.config.claude_agent_max_thinking_tokens", + 0, + ), + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + extra_body = mock_client.chat.completions.create.call_args[1]["extra_body"] + assert "reasoning" not in extra_body + + @pytest.mark.asyncio + async def test_reasoning_persists_to_state_session_messages(self): + """Integration guard: ``_BaselineStreamState.__post_init__`` wires + the emitter to ``state.session_messages``, so reasoning deltas + flowing through ``_baseline_llm_caller`` must produce a + ``role="reasoning"`` row on the state's session list. Catches + regressions where the wiring silently breaks (e.g. a refactor + passes the wrong list reference).""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock( + _make_delta_chunk(reasoning="first "), + _make_delta_chunk(reasoning="thought"), + _make_delta_chunk(content="answer"), + ) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + reasoning_rows = [m for m in state.session_messages if m.role == "reasoning"] + assert len(reasoning_rows) == 1 + assert reasoning_rows[0].content == "first thought" diff --git a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py new file mode 100644 index 0000000000..8d6fb50a53 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py @@ -0,0 +1,756 @@ +"""Integration tests for baseline transcript flow. + +Exercises the real helpers in ``baseline/service.py`` that restore, +validate, load, append to, backfill, and upload the CLI session. +Storage is mocked via ``download_transcript`` / ``upload_transcript`` +patches; no network access is required. +""" + +import json as stdlib_json +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.copilot.baseline.service import ( + _append_gap_to_builder, + _load_prior_transcript, + _record_turn_to_transcript, + _resolve_baseline_model, + _upload_final_transcript, + should_upload_transcript, +) +from backend.copilot.model import ChatMessage +from backend.copilot.service import config +from backend.copilot.transcript import ( + STOP_REASON_END_TURN, + STOP_REASON_TOOL_USE, + TranscriptDownload, +) +from backend.copilot.transcript_builder import TranscriptBuilder +from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult + + +def _make_transcript_content(*roles: str) -> str: + """Build a minimal valid JSONL transcript from role names.""" + lines = [] + parent = "" + for i, role in enumerate(roles): + uid = f"uuid-{i}" + entry: dict = { + "type": role, + "uuid": uid, + "parentUuid": parent, + "message": { + "role": role, + "content": [{"type": "text", "text": f"{role} message {i}"}], + }, + } + if role == "assistant": + entry["message"]["id"] = f"msg_{i}" + entry["message"]["model"] = "test-model" + entry["message"]["type"] = "message" + entry["message"]["stop_reason"] = STOP_REASON_END_TURN + lines.append(stdlib_json.dumps(entry)) + parent = uid + return "\n".join(lines) + "\n" + + +def _make_session_messages(*roles: str) -> list[ChatMessage]: + """Build a list of ChatMessage objects matching the given roles.""" + return [ + ChatMessage(role=r, content=f"{r} message {i}") for i, r in enumerate(roles) + ] + + +class TestResolveBaselineModel: + """Baseline model resolution honours the per-request tier toggle.""" + + def test_advanced_tier_selects_advanced_model(self): + assert _resolve_baseline_model("advanced") == config.advanced_model + + def test_standard_tier_selects_default_model(self): + assert _resolve_baseline_model("standard") == config.model + + def test_none_tier_selects_default_model(self): + """Baseline users without a tier MUST keep the default (standard).""" + assert _resolve_baseline_model(None) == config.model + + def test_standard_and_advanced_models_differ(self): + """Advanced tier defaults to a different (Opus) model than standard.""" + assert config.model != config.advanced_model + + +class TestLoadPriorTranscript: + """``_load_prior_transcript`` wraps the CLI session restore + validate + load flow.""" + + @pytest.mark.asyncio + async def test_loads_fresh_transcript(self): + builder = TranscriptBuilder() + content = _make_transcript_content("user", "assistant") + restore = TranscriptDownload( + content=content.encode("utf-8"), message_count=2, mode="sdk" + ) + + with patch( + "backend.copilot.baseline.service.download_transcript", + new=AsyncMock(return_value=restore), + ): + covers, dl = await _load_prior_transcript( + user_id="user-1", + session_id="session-1", + session_messages=_make_session_messages("user", "assistant", "user"), + transcript_builder=builder, + ) + + assert covers is True + assert dl is not None + assert dl.message_count == 2 + assert builder.entry_count == 2 + assert builder.last_entry_type == "assistant" + + @pytest.mark.asyncio + async def test_fills_gap_when_transcript_is_behind(self): + """When transcript covers fewer messages than session, gap is filled from DB.""" + builder = TranscriptBuilder() + content = _make_transcript_content("user", "assistant") + # transcript covers 2 messages, session has 4 (plus current user turn = 5) + restore = TranscriptDownload( + content=content.encode("utf-8"), message_count=2, mode="baseline" + ) + + with patch( + "backend.copilot.baseline.service.download_transcript", + new=AsyncMock(return_value=restore), + ): + covers, dl = await _load_prior_transcript( + user_id="user-1", + session_id="session-1", + session_messages=_make_session_messages( + "user", "assistant", "user", "assistant", "user" + ), + transcript_builder=builder, + ) + + assert covers is True + assert dl is not None + # 2 from transcript + 2 gap messages (user+assistant at positions 2,3) + assert builder.entry_count == 4 + + @pytest.mark.asyncio + async def test_missing_transcript_allows_upload(self): + """Nothing in GCS → upload is safe; the turn writes the first snapshot.""" + builder = TranscriptBuilder() + with patch( + "backend.copilot.baseline.service.download_transcript", + new=AsyncMock(return_value=None), + ): + upload_safe, dl = await _load_prior_transcript( + user_id="user-1", + session_id="session-1", + session_messages=_make_session_messages("user", "assistant"), + transcript_builder=builder, + ) + + assert upload_safe is True + assert dl is None + assert builder.is_empty + + @pytest.mark.asyncio + async def test_invalid_transcript_allows_upload(self): + """Corrupt file in GCS → overwriting with a valid one is better.""" + builder = TranscriptBuilder() + restore = TranscriptDownload( + content=b'{"type":"progress","uuid":"a"}\n', + message_count=1, + mode="sdk", + ) + with patch( + "backend.copilot.baseline.service.download_transcript", + new=AsyncMock(return_value=restore), + ): + upload_safe, dl = await _load_prior_transcript( + user_id="user-1", + session_id="session-1", + session_messages=_make_session_messages("user", "assistant"), + transcript_builder=builder, + ) + + assert upload_safe is True + assert dl is None + assert builder.is_empty + + @pytest.mark.asyncio + async def test_download_exception_returns_false(self): + builder = TranscriptBuilder() + with patch( + "backend.copilot.baseline.service.download_transcript", + new=AsyncMock(side_effect=RuntimeError("boom")), + ): + covers, dl = await _load_prior_transcript( + user_id="user-1", + session_id="session-1", + session_messages=_make_session_messages("user", "assistant"), + transcript_builder=builder, + ) + + assert covers is False + assert dl is None + assert builder.is_empty + + @pytest.mark.asyncio + async def test_zero_message_count_not_stale(self): + """When msg_count is 0 (unknown), gap detection is skipped.""" + builder = TranscriptBuilder() + restore = TranscriptDownload( + content=_make_transcript_content("user", "assistant").encode("utf-8"), + message_count=0, + mode="sdk", + ) + with patch( + "backend.copilot.baseline.service.download_transcript", + new=AsyncMock(return_value=restore), + ): + covers, dl = await _load_prior_transcript( + user_id="user-1", + session_id="session-1", + session_messages=_make_session_messages(*["user"] * 20), + transcript_builder=builder, + ) + + assert covers is True + assert dl is not None + assert builder.entry_count == 2 + + +class TestUploadFinalTranscript: + """``_upload_final_transcript`` serialises and calls storage.""" + + @pytest.mark.asyncio + async def test_uploads_valid_transcript(self): + builder = TranscriptBuilder() + builder.append_user(content="hi") + builder.append_assistant( + content_blocks=[{"type": "text", "text": "hello"}], + model="test-model", + stop_reason=STOP_REASON_END_TURN, + ) + + upload_mock = AsyncMock(return_value=None) + with patch( + "backend.copilot.baseline.service.upload_transcript", + new=upload_mock, + ): + await _upload_final_transcript( + user_id="user-1", + session_id="session-1", + transcript_builder=builder, + session_msg_count=2, + ) + + upload_mock.assert_awaited_once() + assert upload_mock.await_args is not None + call_kwargs = upload_mock.await_args.kwargs + assert call_kwargs["user_id"] == "user-1" + assert call_kwargs["session_id"] == "session-1" + assert call_kwargs["message_count"] == 2 + assert b"hello" in call_kwargs["content"] + + @pytest.mark.asyncio + async def test_skips_upload_when_builder_empty(self): + builder = TranscriptBuilder() + upload_mock = AsyncMock(return_value=None) + with patch( + "backend.copilot.baseline.service.upload_transcript", + new=upload_mock, + ): + await _upload_final_transcript( + user_id="user-1", + session_id="session-1", + transcript_builder=builder, + session_msg_count=0, + ) + + upload_mock.assert_not_awaited() + + @pytest.mark.asyncio + async def test_swallows_upload_exceptions(self): + """Upload failures should not propagate (flow continues for the user).""" + builder = TranscriptBuilder() + builder.append_user(content="hi") + builder.append_assistant( + content_blocks=[{"type": "text", "text": "hello"}], + model="test-model", + stop_reason=STOP_REASON_END_TURN, + ) + + with patch( + "backend.copilot.baseline.service.upload_transcript", + new=AsyncMock(side_effect=RuntimeError("storage unavailable")), + ): + # Should not raise. + await _upload_final_transcript( + user_id="user-1", + session_id="session-1", + transcript_builder=builder, + session_msg_count=2, + ) + + +class TestRecordTurnToTranscript: + """``_record_turn_to_transcript`` translates LLMLoopResponse → transcript.""" + + def test_records_final_assistant_text(self): + builder = TranscriptBuilder() + builder.append_user(content="hi") + + response = LLMLoopResponse( + response_text="hello there", + tool_calls=[], + raw_response=None, + ) + _record_turn_to_transcript( + response, + tool_results=None, + transcript_builder=builder, + model="test-model", + ) + + assert builder.entry_count == 2 + assert builder.last_entry_type == "assistant" + jsonl = builder.to_jsonl() + assert "hello there" in jsonl + assert STOP_REASON_END_TURN in jsonl + + def test_records_tool_use_then_tool_result(self): + """Anthropic ordering: assistant(tool_use) → user(tool_result).""" + builder = TranscriptBuilder() + builder.append_user(content="use a tool") + + response = LLMLoopResponse( + response_text=None, + tool_calls=[ + LLMToolCall(id="call-1", name="echo", arguments='{"text":"hi"}') + ], + raw_response=None, + ) + tool_results = [ + ToolCallResult(tool_call_id="call-1", tool_name="echo", content="hi") + ] + _record_turn_to_transcript( + response, + tool_results, + transcript_builder=builder, + model="test-model", + ) + + # user, assistant(tool_use), user(tool_result) = 3 entries + assert builder.entry_count == 3 + jsonl = builder.to_jsonl() + assert STOP_REASON_TOOL_USE in jsonl + assert "tool_use" in jsonl + assert "tool_result" in jsonl + assert "call-1" in jsonl + + def test_records_nothing_on_empty_response(self): + builder = TranscriptBuilder() + builder.append_user(content="hi") + + response = LLMLoopResponse( + response_text=None, + tool_calls=[], + raw_response=None, + ) + _record_turn_to_transcript( + response, + tool_results=None, + transcript_builder=builder, + model="test-model", + ) + + assert builder.entry_count == 1 + + def test_malformed_tool_args_dont_crash(self): + """Bad JSON in tool arguments falls back to {} without raising.""" + builder = TranscriptBuilder() + builder.append_user(content="hi") + + response = LLMLoopResponse( + response_text=None, + tool_calls=[LLMToolCall(id="call-1", name="echo", arguments="{not-json")], + raw_response=None, + ) + tool_results = [ + ToolCallResult(tool_call_id="call-1", tool_name="echo", content="ok") + ] + _record_turn_to_transcript( + response, + tool_results, + transcript_builder=builder, + model="test-model", + ) + + assert builder.entry_count == 3 + jsonl = builder.to_jsonl() + assert '"input":{}' in jsonl + + +class TestRoundTrip: + """End-to-end: load prior → append new turn → upload.""" + + @pytest.mark.asyncio + async def test_full_round_trip(self): + prior = _make_transcript_content("user", "assistant") + restore = TranscriptDownload( + content=prior.encode("utf-8"), message_count=2, mode="sdk" + ) + + builder = TranscriptBuilder() + with patch( + "backend.copilot.baseline.service.download_transcript", + new=AsyncMock(return_value=restore), + ): + covers, _ = await _load_prior_transcript( + user_id="user-1", + session_id="session-1", + session_messages=_make_session_messages("user", "assistant", "user"), + transcript_builder=builder, + ) + assert covers is True + assert builder.entry_count == 2 + + # New user turn. + builder.append_user(content="new question") + assert builder.entry_count == 3 + + # New assistant turn. + response = LLMLoopResponse( + response_text="new answer", + tool_calls=[], + raw_response=None, + ) + _record_turn_to_transcript( + response, + tool_results=None, + transcript_builder=builder, + model="test-model", + ) + assert builder.entry_count == 4 + + # Upload. + upload_mock = AsyncMock(return_value=None) + with patch( + "backend.copilot.baseline.service.upload_transcript", + new=upload_mock, + ): + await _upload_final_transcript( + user_id="user-1", + session_id="session-1", + transcript_builder=builder, + session_msg_count=4, + ) + + upload_mock.assert_awaited_once() + assert upload_mock.await_args is not None + uploaded = upload_mock.await_args.kwargs["content"] + assert b"new question" in uploaded + assert b"new answer" in uploaded + # Original content preserved in the round trip. + assert b"user message 0" in uploaded + assert b"assistant message 1" in uploaded + + @pytest.mark.asyncio + async def test_backfill_append_guard(self): + """Backfill only runs when the last entry is not already assistant.""" + builder = TranscriptBuilder() + builder.append_user(content="hi") + + # Simulate the backfill guard from stream_chat_completion_baseline. + assistant_text = "partial text before error" + if builder.last_entry_type != "assistant": + builder.append_assistant( + content_blocks=[{"type": "text", "text": assistant_text}], + model="test-model", + stop_reason=STOP_REASON_END_TURN, + ) + + assert builder.last_entry_type == "assistant" + assert "partial text before error" in builder.to_jsonl() + + # Second invocation: the guard must prevent double-append. + initial_count = builder.entry_count + if builder.last_entry_type != "assistant": + builder.append_assistant( + content_blocks=[{"type": "text", "text": "duplicate"}], + model="test-model", + stop_reason=STOP_REASON_END_TURN, + ) + assert builder.entry_count == initial_count + + +class TestShouldUploadTranscript: + """``should_upload_transcript`` gates the final upload.""" + + def test_upload_allowed_for_user_with_coverage(self): + assert should_upload_transcript("user-1", True) is True + + def test_upload_skipped_when_no_user(self): + assert should_upload_transcript(None, True) is False + + def test_upload_skipped_when_empty_user(self): + assert should_upload_transcript("", True) is False + + def test_upload_skipped_without_coverage(self): + """Partial transcript must never clobber a more complete stored one.""" + assert should_upload_transcript("user-1", False) is False + + def test_upload_skipped_when_no_user_and_no_coverage(self): + assert should_upload_transcript(None, False) is False + + +class TestTranscriptLifecycle: + """End-to-end: restore → validate → build → upload. + + Simulates the full transcript lifecycle inside + ``stream_chat_completion_baseline`` by mocking the storage layer and + driving each step through the real helpers. + """ + + @pytest.mark.asyncio + async def test_full_lifecycle_happy_path(self): + """Fresh restore, append a turn, upload covers the session.""" + builder = TranscriptBuilder() + prior = _make_transcript_content("user", "assistant") + restore = TranscriptDownload( + content=prior.encode("utf-8"), message_count=2, mode="sdk" + ) + + upload_mock = AsyncMock(return_value=None) + with ( + patch( + "backend.copilot.baseline.service.download_transcript", + new=AsyncMock(return_value=restore), + ), + patch( + "backend.copilot.baseline.service.upload_transcript", + new=upload_mock, + ), + ): + # --- 1. Restore & load prior session --- + covers, _ = await _load_prior_transcript( + user_id="user-1", + session_id="session-1", + session_messages=_make_session_messages("user", "assistant", "user"), + transcript_builder=builder, + ) + assert covers is True + + # --- 2. Append a new user turn + a new assistant response --- + builder.append_user(content="follow-up question") + _record_turn_to_transcript( + LLMLoopResponse( + response_text="follow-up answer", + tool_calls=[], + raw_response=None, + ), + tool_results=None, + transcript_builder=builder, + model="test-model", + ) + + # --- 3. Gate + upload --- + assert ( + should_upload_transcript(user_id="user-1", upload_safe=covers) is True + ) + await _upload_final_transcript( + user_id="user-1", + session_id="session-1", + transcript_builder=builder, + session_msg_count=4, + ) + + upload_mock.assert_awaited_once() + assert upload_mock.await_args is not None + uploaded = upload_mock.await_args.kwargs["content"] + assert b"follow-up question" in uploaded + assert b"follow-up answer" in uploaded + # Original prior-turn content preserved. + assert b"user message 0" in uploaded + assert b"assistant message 1" in uploaded + + @pytest.mark.asyncio + async def test_lifecycle_stale_download_fills_gap(self): + """When transcript covers fewer messages, gap is filled rather than rejected.""" + builder = TranscriptBuilder() + # session has 5 msgs but stored transcript only covers 2 → gap filled. + stale = TranscriptDownload( + content=_make_transcript_content("user", "assistant").encode("utf-8"), + message_count=2, + mode="baseline", + ) + + upload_mock = AsyncMock(return_value=None) + with ( + patch( + "backend.copilot.baseline.service.download_transcript", + new=AsyncMock(return_value=stale), + ), + patch( + "backend.copilot.baseline.service.upload_transcript", + new=upload_mock, + ), + ): + covers, _ = await _load_prior_transcript( + user_id="user-1", + session_id="session-1", + session_messages=_make_session_messages( + "user", "assistant", "user", "assistant", "user" + ), + transcript_builder=builder, + ) + + assert covers is True + # Gap was filled: 2 from transcript + 2 gap messages + assert builder.entry_count == 4 + + @pytest.mark.asyncio + async def test_lifecycle_anonymous_user_skips_upload(self): + """Anonymous (user_id=None) → upload gate must return False.""" + builder = TranscriptBuilder() + builder.append_user(content="hi") + builder.append_assistant( + content_blocks=[{"type": "text", "text": "hello"}], + model="test-model", + stop_reason=STOP_REASON_END_TURN, + ) + + assert should_upload_transcript(user_id=None, upload_safe=True) is False + + @pytest.mark.asyncio + async def test_lifecycle_missing_download_still_uploads_new_content(self): + """No prior session → upload is safe; the turn writes the first snapshot.""" + builder = TranscriptBuilder() + upload_mock = AsyncMock(return_value=None) + with ( + patch( + "backend.copilot.baseline.service.download_transcript", + new=AsyncMock(return_value=None), + ), + patch( + "backend.copilot.baseline.service.upload_transcript", + new=upload_mock, + ), + ): + upload_safe, dl = await _load_prior_transcript( + user_id="user-1", + session_id="session-1", + session_messages=_make_session_messages("user"), + transcript_builder=builder, + ) + # Nothing in GCS → upload is safe so the first baseline turn + # can write the initial transcript snapshot. + assert upload_safe is True + assert dl is None + assert ( + should_upload_transcript(user_id="user-1", upload_safe=upload_safe) + is True + ) + + +# --------------------------------------------------------------------------- +# _append_gap_to_builder +# --------------------------------------------------------------------------- + + +class TestAppendGapToBuilder: + """``_append_gap_to_builder`` converts ChatMessage objects to TranscriptBuilder entries.""" + + def test_user_message_appended(self): + builder = TranscriptBuilder() + msgs = [ChatMessage(role="user", content="hello")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + assert builder.last_entry_type == "user" + + def test_assistant_text_message_appended(self): + builder = TranscriptBuilder() + msgs = [ + ChatMessage(role="user", content="q"), + ChatMessage(role="assistant", content="answer"), + ] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 2 + assert builder.last_entry_type == "assistant" + assert "answer" in builder.to_jsonl() + + def test_assistant_with_tool_calls_appended(self): + """Assistant tool_calls are recorded as tool_use blocks in the transcript.""" + builder = TranscriptBuilder() + tool_call = { + "id": "tc-1", + "type": "function", + "function": {"name": "my_tool", "arguments": '{"key":"val"}'}, + } + msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "tool_use" in jsonl + assert "my_tool" in jsonl + assert "tc-1" in jsonl + + def test_assistant_invalid_json_args_uses_empty_dict(self): + """Malformed JSON in tool_call arguments falls back to {}.""" + builder = TranscriptBuilder() + tool_call = { + "id": "tc-bad", + "type": "function", + "function": {"name": "bad_tool", "arguments": "not-json"}, + } + msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert '"input":{}' in jsonl + + def test_assistant_empty_content_and_no_tools_uses_fallback(self): + """Assistant with no content and no tool_calls gets a fallback empty text block.""" + builder = TranscriptBuilder() + msgs = [ChatMessage(role="assistant", content=None)] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "text" in jsonl + + def test_tool_role_with_tool_call_id_appended(self): + """Tool result messages are appended when tool_call_id is set.""" + builder = TranscriptBuilder() + # Need a preceding assistant tool_use entry + builder.append_user("use tool") + builder.append_assistant( + content_blocks=[ + {"type": "tool_use", "id": "tc-1", "name": "my_tool", "input": {}} + ] + ) + msgs = [ChatMessage(role="tool", tool_call_id="tc-1", content="result")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 3 + assert "tool_result" in builder.to_jsonl() + + def test_tool_role_without_tool_call_id_skipped(self): + """Tool messages without tool_call_id are silently skipped.""" + builder = TranscriptBuilder() + msgs = [ChatMessage(role="tool", tool_call_id=None, content="orphan")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 0 + + def test_tool_call_missing_function_key_uses_unknown_name(self): + """A tool_call dict with no 'function' key uses 'unknown' as the tool name.""" + builder = TranscriptBuilder() + # Tool call dict exists but 'function' sub-dict is missing entirely + msgs = [ + ChatMessage(role="assistant", content=None, tool_calls=[{"id": "tc-x"}]) + ] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "unknown" in jsonl diff --git a/autogpt_platform/backend/backend/copilot/builder_context.py b/autogpt_platform/backend/backend/copilot/builder_context.py new file mode 100644 index 0000000000..9f36350d1c --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/builder_context.py @@ -0,0 +1,217 @@ +"""Builder-session context helpers — split cacheable system prompt from +the volatile per-turn snapshot so Claude's prompt cache stays warm.""" + +from __future__ import annotations + +import logging +from typing import Any + +from backend.copilot.model import ChatSession +from backend.copilot.permissions import CopilotPermissions +from backend.copilot.tools.agent_generator import get_agent_as_json +from backend.copilot.tools.get_agent_building_guide import _load_guide + +logger = logging.getLogger(__name__) + + +BUILDER_CONTEXT_TAG = "builder_context" +BUILDER_SESSION_TAG = "builder_session" + + +# Tools hidden from builder-bound sessions: ``create_agent`` / +# ``customize_agent`` would mint a new graph (panel is bound to one), +# and ``get_agent_building_guide`` duplicates bytes already in the +# system-prompt suffix. Everything else (find_block, find_agent, …) +# stays available so the LLM can look up ids instead of hallucinating. +BUILDER_BLOCKED_TOOLS: tuple[str, ...] = ( + "create_agent", + "customize_agent", + "get_agent_building_guide", +) + + +def resolve_session_permissions( + session: ChatSession | None, +) -> CopilotPermissions | None: + """Blacklist :data:`BUILDER_BLOCKED_TOOLS` for builder-bound sessions, + return ``None`` (unrestricted) otherwise.""" + if session is None or not session.metadata.builder_graph_id: + return None + return CopilotPermissions( + tools=list(BUILDER_BLOCKED_TOOLS), + tools_exclude=True, + ) + + +# Caps — mirror the frontend ``serializeGraphForChat`` defaults so the +# server-side block stays within a practical token budget for large graphs. +_MAX_NODES = 100 +_MAX_LINKS = 200 + +_FETCH_FAILED_PREFIX = ( + f"<{BUILDER_CONTEXT_TAG}>\n" + f"fetch_failed\n" + f"\n\n" +) + +# Embedded in the cacheable suffix so the LLM picks the right run_agent +# dispatch mode without forcing the user to watch a long-blocking call. +_BUILDER_RUN_AGENT_GUIDANCE = ( + "You are operating inside the builder panel, not the standalone " + "copilot page. The builder page already subscribes to agent " + "executions the moment you return an execution_id, so for REAL " + "(non-dry) runs prefer `run_agent(dry_run=False, wait_for_result=0)` " + "— the user will see the run stream in the builder's execution panel " + "in-place and your turn ends immediately with the id. For DRY-RUNS " + "keep `dry_run=True, wait_for_result=120`: blocking is required so " + "you can inspect `execution.node_executions` and report the verdict " + "in the same turn." +) + + +def _sanitize_for_xml(value: Any) -> str: + """Escape XML special chars — mirrors ``sanitizeForXml`` in + ``BuilderChatPanel/helpers.ts``.""" + s = "" if value is None else str(value) + return ( + s.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'") + ) + + +def _node_display_name(node: dict[str, Any]) -> str: + """Prefer the user-set label (``input_default.name`` / ``metadata.title``); + fall back to the block id.""" + defaults = node.get("input_default") or {} + metadata = node.get("metadata") or {} + for key in ("name", "title", "label"): + value = defaults.get(key) or metadata.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + block_id = node.get("block_id") or "" + return block_id or "unknown" + + +def _format_nodes(nodes: list[dict[str, Any]]) -> str: + if not nodes: + return "\n" + visible = nodes[:_MAX_NODES] + lines = [] + for node in visible: + node_id = _sanitize_for_xml(node.get("id") or "") + name = _sanitize_for_xml(_node_display_name(node)) + block_id = _sanitize_for_xml(node.get("block_id") or "") + lines.append(f"- {node_id}: {name} ({block_id})") + extra = len(nodes) - len(visible) + if extra > 0: + lines.append(f"({extra} more not shown)") + body = "\n".join(lines) + return f"\n{body}\n" + + +def _format_links( + links: list[dict[str, Any]], + nodes: list[dict[str, Any]], +) -> str: + if not links: + return "\n" + name_by_id = {n.get("id"): _node_display_name(n) for n in nodes} + visible = links[:_MAX_LINKS] + lines = [] + for link in visible: + src_id = link.get("source_id") or "" + dst_id = link.get("sink_id") or "" + src_name = name_by_id.get(src_id, src_id) + dst_name = name_by_id.get(dst_id, dst_id) + src_out = link.get("source_name") or "" + dst_in = link.get("sink_name") or "" + lines.append( + f"- {_sanitize_for_xml(src_name)}.{_sanitize_for_xml(src_out)} " + f"-> {_sanitize_for_xml(dst_name)}.{_sanitize_for_xml(dst_in)}" + ) + extra = len(links) - len(visible) + if extra > 0: + lines.append(f"({extra} more not shown)") + body = "\n".join(lines) + return f"\n{body}\n" + + +async def build_builder_system_prompt_suffix(session: ChatSession) -> str: + """Return the cacheable system-prompt suffix for a builder session. + + Holds only static content (dispatch guidance + building guide) so the + bytes are identical across turns AND across sessions for different + graphs — the live id/name/version ride on the per-turn prefix. + """ + if not session.metadata.builder_graph_id: + return "" + + try: + guide = _load_guide() + except Exception: + logger.exception("[builder_context] Failed to load agent-building guide") + return "" + + # The guide is trusted server-side content (read from disk). We do NOT + # escape it — the LLM needs the raw markdown to make sense of block ids, + # code fences, and example JSON. + return ( + f"\n\n<{BUILDER_SESSION_TAG}>\n" + f"\n" + f"{_BUILDER_RUN_AGENT_GUIDANCE}\n" + f"\n" + f"\n{guide}\n\n" + f"" + ) + + +async def build_builder_context_turn_prefix( + session: ChatSession, + user_id: str | None, +) -> str: + """Return the per-turn ```` prefix with the live + graph snapshot (id/name/version/nodes/links). ``""`` for non-builder + sessions; fetch-failure marker if the graph cannot be read.""" + graph_id = session.metadata.builder_graph_id + if not graph_id: + return "" + + try: + agent_json = await get_agent_as_json(graph_id, user_id) + except Exception: + logger.exception( + "[builder_context] Failed to fetch graph %s for session %s", + graph_id, + session.session_id, + ) + return _FETCH_FAILED_PREFIX + + if not agent_json: + logger.warning( + "[builder_context] Graph %s not found for session %s", + graph_id, + session.session_id, + ) + return _FETCH_FAILED_PREFIX + + version = _sanitize_for_xml(agent_json.get("version") or "") + raw_name = agent_json.get("name") + graph_name = ( + raw_name.strip() if isinstance(raw_name, str) and raw_name.strip() else None + ) + nodes = agent_json.get("nodes") or [] + links = agent_json.get("links") or [] + name_attr = f' name="{_sanitize_for_xml(graph_name)}"' if graph_name else "" + graph_tag = ( + f'' + ) + + inner = f"{graph_tag}\n{_format_nodes(nodes)}\n{_format_links(links, nodes)}" + return f"<{BUILDER_CONTEXT_TAG}>\n{inner}\n\n\n" diff --git a/autogpt_platform/backend/backend/copilot/builder_context_test.py b/autogpt_platform/backend/backend/copilot/builder_context_test.py new file mode 100644 index 0000000000..efeb6f7dad --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/builder_context_test.py @@ -0,0 +1,329 @@ +"""Tests for the split builder-context helpers. + +Covers both halves of the public API: + +- :func:`build_builder_system_prompt_suffix` — session-stable block + appended to the system prompt (contains the guide + graph id/name). +- :func:`build_builder_context_turn_prefix` — per-turn user-message + prefix (contains the live version + node/link snapshot). +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.copilot.builder_context import ( + BUILDER_CONTEXT_TAG, + BUILDER_SESSION_TAG, + build_builder_context_turn_prefix, + build_builder_system_prompt_suffix, +) +from backend.copilot.model import ChatSession + + +def _session( + builder_graph_id: str | None, + *, + user_id: str = "test-user", +) -> ChatSession: + """Minimal ``ChatSession`` with *builder_graph_id* on metadata.""" + return ChatSession.new( + user_id, + dry_run=False, + builder_graph_id=builder_graph_id, + ) + + +def _agent_json( + nodes: list[dict] | None = None, + links: list[dict] | None = None, + **overrides, +) -> dict: + base: dict = { + "id": "graph-1", + "name": "My Agent", + "description": "A test agent", + "version": 3, + "is_active": True, + "nodes": nodes if nodes is not None else [], + "links": links if links is not None else [], + } + base.update(overrides) + return base + + +# --------------------------------------------------------------------------- +# build_builder_system_prompt_suffix +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_system_prompt_suffix_empty_for_non_builder(): + session = _session(None) + result = await build_builder_system_prompt_suffix(session) + assert result == "" + + +@pytest.mark.asyncio +async def test_system_prompt_suffix_contains_only_static_content(): + session = _session("graph-1") + with patch( + "backend.copilot.builder_context._load_guide", + return_value="# Guide body", + ): + suffix = await build_builder_system_prompt_suffix(session) + + assert suffix.startswith("\n\n") + assert f"<{BUILDER_SESSION_TAG}>" in suffix + assert f"" in suffix + assert "" in suffix + assert "# Guide body" in suffix + # Dispatch-mode guidance must appear so the LLM knows to prefer + # wait_for_result=0 for real runs (builder UI subscribes live) and + # wait_for_result=120 for dry-runs (so it can inspect the node trace). + assert "" in suffix + assert "wait_for_result=0" in suffix + assert "wait_for_result=120" in suffix + # Regression: dynamic graph id/name must NOT leak into the cacheable + # suffix — they live in the per-turn prefix so renames and cross-graph + # sessions don't invalidate Claude's prompt cache. + assert "graph-1" not in suffix + assert "id=" not in suffix + assert "name=" not in suffix + + +@pytest.mark.asyncio +async def test_system_prompt_suffix_identical_across_graphs(): + """The suffix must be byte-identical regardless of which graph the + session is bound to — that's what keeps the cacheable prefix warm + across sessions.""" + s1 = _session("graph-1") + s2 = _session("graph-2", user_id="different-owner") + with patch( + "backend.copilot.builder_context._load_guide", + return_value="# Guide body", + ): + suffix_1 = await build_builder_system_prompt_suffix(s1) + suffix_2 = await build_builder_system_prompt_suffix(s2) + + assert suffix_1 == suffix_2 + + +@pytest.mark.asyncio +async def test_system_prompt_suffix_empty_when_guide_load_fails(): + """Guide load failure means we have nothing useful to add — emit an + empty suffix rather than a half-built block.""" + session = _session("graph-1") + with patch( + "backend.copilot.builder_context._load_guide", + side_effect=OSError("missing"), + ): + suffix = await build_builder_system_prompt_suffix(session) + + assert suffix == "" + + +# --------------------------------------------------------------------------- +# build_builder_context_turn_prefix +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_turn_prefix_empty_for_non_builder(): + session = _session(None) + result = await build_builder_context_turn_prefix(session, "user-1") + assert result == "" + + +@pytest.mark.asyncio +async def test_turn_prefix_contains_version_nodes_and_links(): + session = _session("graph-1") + nodes = [ + { + "id": "n1", + "block_id": "block-A", + "input_default": {"name": "Input"}, + "metadata": {}, + }, + { + "id": "n2", + "block_id": "block-B", + "input_default": {}, + "metadata": {}, + }, + ] + links = [ + { + "source_id": "n1", + "sink_id": "n2", + "source_name": "out", + "sink_name": "in", + } + ] + agent = _agent_json(nodes=nodes, links=links) + with patch( + "backend.copilot.builder_context.get_agent_as_json", + new=AsyncMock(return_value=agent), + ): + block = await build_builder_context_turn_prefix(session, "user-1") + + assert block.startswith(f"<{BUILDER_CONTEXT_TAG}>\n") + assert block.endswith(f"\n\n") + assert 'id="graph-1"' in block + assert 'name="My Agent"' in block + assert 'version="3"' in block + assert 'node_count="2"' in block + assert 'edge_count="1"' in block + assert "n1: Input (block-A)" in block + assert "n2: block-B (block-B)" in block + assert "Input.out -> block-B.in" in block + + +@pytest.mark.asyncio +async def test_turn_prefix_does_not_include_guide(): + """The guide lives in the cacheable system prompt, not in the per-turn + prefix.""" + session = _session("graph-1") + with ( + patch( + "backend.copilot.builder_context.get_agent_as_json", + new=AsyncMock(return_value=_agent_json()), + ), + # Sentinel guide text — if it leaks into the turn prefix the + # assertion below catches it. + patch( + "backend.copilot.builder_context._load_guide", + return_value="SENTINEL_GUIDE_BODY", + ), + ): + block = await build_builder_context_turn_prefix(session, "user-1") + + assert "SENTINEL_GUIDE_BODY" not in block + assert "" not in block + + +@pytest.mark.asyncio +async def test_turn_prefix_escapes_graph_name(): + session = _session("graph-1") + with patch( + "backend.copilot.builder_context.get_agent_as_json", + new=AsyncMock(return_value=_agent_json(name='`; + const wrapped = wrapWithHeadInjection( + content, + tailwindScript + FRAGMENT_LINK_INTERCEPTOR_SCRIPT, + ); + return ( +