Compare commits

..

2 Commits

Author SHA1 Message Date
Nicholas Tindle
001ef93b37 Merge branch 'dev' into fix/open-2895-sheets-missing-credentials 2026-03-27 15:00:45 -05:00
Nicholas Tindle
cd6ddfa7e7 fix(blocks): prevent TypeError when auto_credentials field is empty
When a GoogleDriveFileField input (e.g. spreadsheet) is empty/None, the
executor skips credential injection. The block's run() is then called
without the required `credentials` kwarg, causing a TypeError wrapped as
BlockUnknownError.

Fix in two layers:
- _base.py: inject None for missing auto_credentials kwargs in _execute()
  so all GoogleDriveFileField blocks get a clean fallback
- GoogleSheetsReadBlock: accept optional credentials and guard against None

Fixes OPEN-2895

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:38:19 -05:00
2693 changed files with 826757 additions and 85244 deletions

View File

@@ -1 +0,0 @@
../.claude/skills

View File

@@ -1,10 +0,0 @@
{
"permissions": {
"allowedTools": [
"Read", "Grep", "Glob",
"Bash(ls:*)", "Bash(cat:*)", "Bash(grep:*)", "Bash(find:*)",
"Bash(git status:*)", "Bash(git diff:*)", "Bash(git log:*)", "Bash(git worktree:*)",
"Bash(tmux:*)", "Bash(sleep:*)", "Bash(branchlet:*)"
]
}
}

View File

@@ -1,545 +0,0 @@
---
name: orchestrate
description: "Meta-agent supervisor that manages a fleet of Claude Code agents running in tmux windows. Auto-discovers spare worktrees, spawns agents, monitors state, kicks idle agents, approves safe confirmations, and recycles worktrees when done. TRIGGER when user asks to supervise agents, run parallel tasks, manage worktrees, check agent status, or orchestrate parallel work."
user-invocable: true
argument-hint: "any free text — e.g. 'start 3 agents on X Y Z', 'show status', 'add task: implement feature A', 'stop', 'how many are free?'"
metadata:
author: autogpt-team
version: "6.0.0"
---
# Orchestrate — Agent Fleet Supervisor
One tmux session, N windows — each window is one agent working in its own worktree. Speak naturally; Claude maps your intent to the right scripts.
## Scripts
```bash
SKILLS_DIR=$(git rev-parse --show-toplevel)/.claude/skills/orchestrate/scripts
STATE_FILE=~/.claude/orchestrator-state.json
```
| Script | Purpose |
|---|---|
| `find-spare.sh [REPO_ROOT]` | List free worktrees — one `PATH BRANCH` per line |
| `spawn-agent.sh SESSION PATH SPARE NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]` | Create window + checkout branch + launch claude + send task. **Stdout: `SESSION:WIN` only** |
| `recycle-agent.sh WINDOW PATH SPARE_BRANCH` | Kill window + restore spare branch |
| `run-loop.sh` | **Mechanical babysitter** — idle restart + dialog approval + recycle on ORCHESTRATOR:DONE + supervisor health check + all-done notification |
| `verify-complete.sh WINDOW` | Verify PR is done: checkpoints ✓ + 0 unresolved threads + CI green + no fresh CHANGES_REQUESTED. Repo auto-derived from state file `.repo` or git remote. |
| `notify.sh MESSAGE` | Send notification via Discord webhook (env `DISCORD_WEBHOOK_URL` or state `.discord_webhook`), macOS notification center, and stdout |
| `capacity.sh [REPO_ROOT]` | Print available + in-use worktrees |
| `status.sh` | Print fleet status + live pane commands |
| `poll-cycle.sh` | One monitoring cycle — classifies panes, tracks checkpoints, returns JSON action array |
| `classify-pane.sh WINDOW` | Classify one pane state |
## Supervision model
```
Orchestrating Claude (this Claude session — IS the supervisor)
└── Reads pane output, checks CI, intervenes with targeted guidance
run-loop.sh (separate tmux window, every 30s)
└── Mechanical only: idle restart, dialog approval, recycle on ORCHESTRATOR:DONE
```
**You (the orchestrating Claude)** are the supervisor. After spawning agents, stay in this conversation and actively monitor: poll each agent's pane every 2-3 minutes, check CI, nudge stalled agents, and verify completions. Do not spawn a separate supervisor Claude window — it loses context, is hard to observe, and compounds context compression problems.
**run-loop.sh** is the mechanical layer — zero tokens, handles things that need no judgment: restart crashed agents, press Enter on dialogs, recycle completed worktrees (only after `verify-complete.sh` passes).
## Checkpoint protocol
Agents output checkpoints as they complete each required step:
```
CHECKPOINT:<step-name>
```
Required steps are passed as args to `spawn-agent.sh` (e.g. `pr-address pr-test`). `run-loop.sh` will not recycle a window until all required checkpoints are found in the pane output. If `verify-complete.sh` fails, the agent is re-briefed automatically.
## Worktree lifecycle
```text
spare/N branch → spawn-agent.sh (--session-id UUID) → window + feat/branch + claude running
CHECKPOINT:<step> (as steps complete)
ORCHESTRATOR:DONE
verify-complete.sh: checkpoints ✓ + 0 threads + CI green + no fresh CHANGES_REQUESTED
state → "done", notify, window KEPT OPEN
user/orchestrator explicitly requests recycle
recycle-agent.sh → spare/N (free again)
```
**Windows are never auto-killed.** The worktree stays on its branch, the session stays alive. The agent is done working but the window, git state, and Claude session are all preserved until you choose to recycle.
**To resume a done or crashed session:**
```bash
# Resume by stored session ID (preferred — exact session, full context)
claude --resume SESSION_ID --permission-mode bypassPermissions
# Or resume most recent session in that worktree directory
cd /path/to/worktree && claude --continue --permission-mode bypassPermissions
```
**To manually recycle when ready:**
```bash
bash ~/.claude/orchestrator/scripts/recycle-agent.sh SESSION:WIN WORKTREE_PATH spare/N
# Then update state:
jq --arg w "SESSION:WIN" '.agents |= map(if .window == $w then .state = "recycled" else . end)' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
## State file (`~/.claude/orchestrator-state.json`)
Never committed to git. You maintain this file directly using `jq` + atomic writes (`.tmp``mv`).
```json
{
"active": true,
"tmux_session": "autogpt1",
"idle_threshold_seconds": 300,
"loop_window": "autogpt1:5",
"repo": "Significant-Gravitas/AutoGPT",
"discord_webhook": "https://discord.com/api/webhooks/...",
"last_poll_at": 0,
"agents": [
{
"window": "autogpt1:3",
"worktree": "AutoGPT6",
"worktree_path": "/path/to/AutoGPT6",
"spare_branch": "spare/6",
"branch": "feat/my-feature",
"objective": "Implement X and open a PR",
"pr_number": "12345",
"session_id": "550e8400-e29b-41d4-a716-446655440000",
"steps": ["pr-address", "pr-test"],
"checkpoints": ["pr-address"],
"state": "running",
"last_output_hash": "",
"last_seen_at": 0,
"spawned_at": 0,
"idle_since": 0,
"revision_count": 0,
"last_rebriefed_at": 0
}
]
}
```
Top-level optional fields:
- `repo` — GitHub `owner/repo` for CI/thread checks. Auto-derived from git remote if omitted.
- `discord_webhook` — Discord webhook URL for completion notifications. Also reads `DISCORD_WEBHOOK_URL` env var.
Per-agent fields:
- `session_id` — UUID passed to `claude --session-id` at spawn; use with `claude --resume UUID` to restore exact session context after a crash or window close.
- `last_rebriefed_at` — Unix timestamp of last re-brief; enforces 5-min cooldown to prevent spam.
Agent states: `running` | `idle` | `stuck` | `waiting_approval` | `complete` | `done` | `escalated`
`done` means verified complete — window is still open, session still alive, worktree still on task branch. Not recycled yet.
## Serial /pr-test rule
`/pr-test` and `/pr-test --fix` run local Docker + integration tests that use shared ports, a shared database, and shared build caches. **Running two `/pr-test` jobs simultaneously will cause port conflicts and database corruption.**
**Rule: only one `/pr-test` runs at a time. The orchestrator serializes them.**
You (the orchestrating Claude) own the test queue:
1. Agents do `pr-review` and `pr-address` in parallel — that's safe (they only push code and reply to GitHub).
2. When a PR needs local testing, add it to your mental queue — don't give agents a `pr-test` step.
3. Run `/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER --fix` yourself, sequentially.
4. Feed results back to the relevant agent via `tmux send-keys`:
```bash
tmux send-keys -t SESSION:WIN "Local tests for PR #N: <paste failure output or 'all passed'>. Fix any failures and push, then output ORCHESTRATOR:DONE."
sleep 0.3
tmux send-keys -t SESSION:WIN Enter
```
5. Wait for CI to confirm green before marking the agent done.
If multiple PRs need testing at the same time, pick the one furthest along (fewest pending CI checks) and test it first. Only start the next test after the previous one completes.
## Session restore (tested and confirmed)
Agent sessions are saved to disk. To restore a closed or crashed session:
```bash
# If session_id is in state (preferred):
NEW_WIN=$(tmux new-window -t SESSION -n WORKTREE_NAME -P -F '#{window_index}')
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --resume SESSION_ID --permission-mode bypassPermissions" Enter
# If no session_id (use --continue for most recent session in that directory):
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --continue --permission-mode bypassPermissions" Enter
```
`--continue` restores the full conversation history including all tool calls, file edits, and context. The agent resumes exactly where it left off. After restoring, update the window address in the state file:
```bash
jq --arg old "SESSION:OLD_WIN" --arg new "SESSION:NEW_WIN" \
'(.agents[] | select(.window == $old)).window = $new' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
## Intent → action mapping
Match the user's message to one of these intents:
| The user says something like… | What to do |
|---|---|
| "status", "what's running", "show agents" | Run `status.sh` + `capacity.sh`, show output |
| "how many free", "capacity", "available worktrees" | Run `capacity.sh`, show output |
| "start N agents on X, Y, Z" or "run these tasks: …" | See **Spawning agents** below |
| "add task: …", "add one more agent for …" | See **Adding an agent** below |
| "stop", "shut down", "pause the fleet" | See **Stopping** below |
| "poll", "check now", "run a cycle" | Run `poll-cycle.sh`, process actions |
| "recycle window X", "free up autogpt3" | Run `recycle-agent.sh` directly |
When the intent is ambiguous, show capacity first and ask what tasks to run.
## Spawning agents
### 1. Resolve tmux session
```bash
tmux list-sessions -F "#{session_name}: #{session_windows} windows" 2>/dev/null
```
Use an existing session. **Never create a tmux session from within Claude** — it becomes a child of Claude's process and dies when the session ends. If no session exists, tell the user to run `tmux new-session -d -s autogpt1` in their terminal first, then re-invoke `/orchestrate`.
### 2. Show available capacity
```bash
bash $SKILLS_DIR/capacity.sh $(git rev-parse --show-toplevel)
```
### 3. Collect tasks from the user
For each task, gather:
- **objective** — what to do (e.g. "implement feature X and open a PR")
- **branch name** — e.g. `feat/my-feature` (derive from objective if not given)
- **pr_number** — GitHub PR number if working on an existing PR (for verification)
- **steps** — required checkpoint names in order (e.g. `pr-address pr-test`) — derive from objective
Ask for `idle_threshold_seconds` only if the user mentions it (default: 300).
Never ask the user to specify a worktree — auto-assign from `find-spare.sh`.
### 4. Spawn one agent per task
```bash
# Get ordered list of spare worktrees
SPARE_LIST=$(bash $SKILLS_DIR/find-spare.sh $(git rev-parse --show-toplevel))
# For each task, take the next spare line:
WORKTREE_PATH=$(echo "$SPARE_LINE" | awk '{print $1}')
SPARE_BRANCH=$(echo "$SPARE_LINE" | awk '{print $2}')
# With PR number and required steps:
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE" "$PR_NUMBER" "pr-address" "pr-test")
# Without PR (new work):
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE")
```
Build an agent record and append it to the state file. If the state file doesn't exist yet, initialize it:
```bash
# Derive repo from git remote (used by verify-complete.sh + supervisor)
REPO=$(git remote get-url origin 2>/dev/null | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
jq -n \
--arg session "$SESSION" \
--arg repo "$REPO" \
--argjson threshold 300 \
'{active:true, tmux_session:$session, idle_threshold_seconds:$threshold,
repo:$repo, loop_window:null, supervisor_window:null, last_poll_at:0, agents:[]}' \
> ~/.claude/orchestrator-state.json
```
Optionally add a Discord webhook for completion notifications:
```bash
jq --arg hook "$DISCORD_WEBHOOK_URL" '.discord_webhook = $hook' ~/.claude/orchestrator-state.json \
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
`spawn-agent.sh` writes the initial agent record (window, worktree_path, branch, objective, state, etc.) to the state file automatically — **do not append the record again after calling it.** The record already exists and `pr_number`/`steps` are patched in by the script itself.
### 5. Start the mechanical babysitter
```bash
LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter
jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
### 6. Begin supervising directly in this conversation
You are the supervisor. After spawning, immediately start your first poll loop (see **Supervisor duties** below) and continue every 2-3 minutes. Do NOT spawn a separate supervisor Claude window.
## Adding an agent
Find the next spare worktree, then spawn and append to state — same as steps 24 above but for a single task. If no spare worktrees are available, tell the user.
## Supervisor duties (YOUR job, every 2-3 min in this conversation)
You are the supervisor. Run this poll loop directly in your Claude session — not in a separate window.
### Poll loop mechanism
You are reactive — you only act when a tool completes or the user sends a message. To create a self-sustaining poll loop without user involvement:
1. Start each poll with `run_in_background: true` + a sleep before the work:
```bash
sleep 120 && tmux capture-pane -t autogpt1:0 -p -S -200 | tail -40
# + similar for each active window
```
2. When the background job notifies you, read the pane output and take action.
3. Immediately schedule the next background poll — this keeps the loop alive.
4. Stop scheduling when all agents are done/escalated.
**Never tell the user "I'll poll every 2-3 minutes"** — that does nothing without a trigger. Start the background job instead.
### Each poll: what to check
```bash
# 1. Read state
cat ~/.claude/orchestrator-state.json | jq '.agents[] | {window, worktree, branch, state, pr_number, checkpoints}'
# 2. For each running/stuck/idle agent, capture pane
tmux capture-pane -t SESSION:WIN -p -S -200 | tail -60
```
For each agent, decide:
| What you see | Action |
|---|---|
| Spinner / tools running | Do nothing — agent is working |
| Idle `` prompt, no `ORCHESTRATOR:DONE` | Stalled — send specific nudge with objective from state |
| Stuck in error loop | Send targeted fix with exact error + solution |
| Waiting for input / question | Answer and unblock via `tmux send-keys` |
| CI red | `gh pr checks PR_NUMBER --repo REPO` → tell agent exactly what's failing |
| Context compacted / agent lost | Send recovery: `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="WIN")'` + `gh pr view PR_NUMBER --json title,body` |
| `ORCHESTRATOR:DONE` in output | Run `verify-complete.sh` — if it fails, re-brief with specific reason |
### Strict ORCHESTRATOR:DONE gate
`verify-complete.sh` handles the main checks automatically (checkpoints, threads, CI green, spawned_at, and CHANGES_REQUESTED). Run it:
**CHANGES_REQUESTED staleness rule**: a `CHANGES_REQUESTED` review only blocks if it was submitted *after* the latest commit. If the latest commit postdates the review, the review is considered stale (feedback already addressed) and does not block. This avoids false negatives when a bot reviewer hasn't re-reviewed after the agent's fixing commits.
```bash
SKILLS_DIR=~/.claude/orchestrator/scripts
bash $SKILLS_DIR/verify-complete.sh SESSION:WIN
```
If it passes → run-loop.sh will recycle the window automatically. No manual action needed.
If it fails → re-brief the agent with the failure reason. Never manually mark state `done` to bypass this.
### Re-brief a stalled agent
```bash
OBJ=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .objective' ~/.claude/orchestrator-state.json)
PR=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .pr_number' ~/.claude/orchestrator-state.json)
tmux send-keys -t SESSION:WIN "You appear stalled. Your objective: $OBJ. Check: gh pr view $PR --json title,body,headRefName to reorient."
sleep 0.3
tmux send-keys -t SESSION:WIN Enter
```
If `image_path` is set on the agent record, include: "Re-read context at IMAGE_PATH with the Read tool."
## Self-recovery protocol (agents)
spawn-agent.sh automatically includes this instruction in every objective:
> If your context compacts and you lose track of what to do, run:
> `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="SESSION:WIN")'`
> and `gh pr view PR_NUMBER --json title,body,headRefName` to reorient.
> Output each completed step as `CHECKPOINT:<step-name>` on its own line.
## Passing images and screenshots to agents
`tmux send-keys` is text-only — you cannot paste a raw image into a pane. To give an agent visual context (screenshots, diagrams, mockups):
1. **Save the image to a temp file** with a stable path:
```bash
# If the user drags in a screenshot or you receive a file path:
IMAGE_PATH="/tmp/orchestrator-context-$(date +%s).png"
cp "$USER_PROVIDED_PATH" "$IMAGE_PATH"
```
2. **Reference the path in the objective string**:
```bash
OBJECTIVE="Implement the layout shown in /tmp/orchestrator-context-1234567890.png. Read that image first with the Read tool to understand the design."
```
3. The agent uses its `Read` tool to view the image at startup — Claude Code agents are multimodal and can read image files directly.
**Rule**: always use `/tmp/orchestrator-context-<timestamp>.png` as the naming convention so the supervisor knows what to look for if it needs to re-brief an agent with the same image.
---
## Orchestrator final evaluation (YOU decide, not the script)
`verify-complete.sh` is a gate — it blocks premature marking. But it cannot tell you if the work is actually good. That is YOUR job.
When run-loop marks an agent `pending_evaluation` and you're notified, do all of these before marking done:
### 1. Run /pr-test (required, serialized, use TodoWrite to queue)
`/pr-test` is the only reliable confirmation that the objective is actually met. Run it yourself, not the agent.
**When multiple PRs reach `pending_evaluation` at the same time, use TodoWrite to queue them:**
```
- [ ] /pr-test PR #12636 — fix copilot retry logic
- [ ] /pr-test PR #12699 — builder chat panel
```
Run one at a time. Check off as you go.
```
/pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/PR_NUMBER
```
**/pr-test can be lazy** — if it gives vague output, re-run with full context:
```
/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER
Context: This PR implements <objective from state file>. Key files: <list>.
Please verify: <specific behaviors to check>.
```
Only one `/pr-test` at a time — they share ports and DB.
### /pr-test result evaluation
**PARTIAL on any headline feature scenario is an immediate blocker.** Do not approve, do not mark done, do not let the agent output `ORCHESTRATOR:DONE`.
| `/pr-test` result | Action |
|---|---|
| All headline scenarios **PASS** | Proceed to evaluation step 2 |
| Any headline scenario **PARTIAL** | Re-brief the agent immediately — see below |
| Any headline scenario **FAIL** | Re-brief the agent immediately |
**What PARTIAL means**: the feature is only partly working. Example: the Apply button never appeared, or the AI returned no action blocks. The agent addressed part of the objective but not all of it.
**When any headline scenario is PARTIAL or FAIL:**
1. Do NOT mark the agent done or accept `ORCHESTRATOR:DONE`
2. Re-brief the agent with the specific scenario that failed and what was missing:
```bash
tmux send-keys -t SESSION:WIN "PARTIAL result on /pr-test — S5 (Apply button) never appeared. The AI must output JSON action blocks for the Apply button to render. Fix this before re-running /pr-test."
sleep 0.3
tmux send-keys -t SESSION:WIN Enter
```
3. Set state back to `running`:
```bash
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "running"' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
4. Wait for new `ORCHESTRATOR:DONE`, then re-run `/pr-test` from scratch
**Rule: only ALL-PASS qualifies for approval.** A mix of PASS + PARTIAL is a failure.
> **Why this matters**: PR #12699 was wrongly approved with S5 PARTIAL — the AI never output JSON action blocks so the Apply button never appeared. The fix was already in the agent's reach but slipped through because PARTIAL was not treated as blocking.
### 2. Do your own evaluation
1. **Read the PR diff and objective** — does the code actually implement what was asked? Is anything obviously missing or half-done?
2. **Read the resolved threads** — were comments addressed with real fixes, or just dismissed/resolved without changes?
3. **Check CI run names** — any suspicious retries that shouldn't have passed?
4. **Check the PR description** — title, summary, test plan complete?
### 3. Decide
- `/pr-test` all scenarios PASS + evaluation looks good → mark `done` in state, tell the user the PR is ready, ask if window should be closed
- `/pr-test` any scenario PARTIAL or FAIL → re-brief the agent with the specific failing scenario, set state back to `running` (see `/pr-test result evaluation` above)
- Evaluation finds gaps even with all PASS → re-brief the agent with specific gaps, set state back to `running`
**Never mark done based purely on script output.** You hold the full objective context; the script does not.
```bash
# Mark done after your positive evaluation:
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "done"' \
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
```
## When to stop the fleet
Stop the fleet (`active = false`) when **all** of the following are true:
| Check | How to verify |
|---|---|
| All agents are `done` or `escalated` | `jq '[.agents[] | select(.state | test("running\|stuck\|idle\|waiting_approval"))] | length' ~/.claude/orchestrator-state.json` == 0 |
| All PRs have 0 unresolved review threads | GraphQL `isResolved` check per PR |
| All PRs have green CI **on a run triggered after the agent's last push** | `gh run list --branch BRANCH --limit 1` timestamp > `spawned_at` in state |
| No fresh CHANGES_REQUESTED (after latest commit) | `verify-complete.sh` checks this — stale pre-commit reviews are ignored |
| No agents are `escalated` without human review | If any are escalated, surface to user first |
**Do NOT stop just because agents output `ORCHESTRATOR:DONE`.** That is a signal to verify, not a signal to stop.
**Do stop** if the user explicitly says "stop", "shut down", or "kill everything", even with agents still running.
```bash
# Graceful stop
jq '.active = false' ~/.claude/orchestrator-state.json > /tmp/orch.tmp \
&& mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
LOOP_WINDOW=$(jq -r '.loop_window // ""' ~/.claude/orchestrator-state.json)
[ -n "$LOOP_WINDOW" ] && tmux kill-window -t "$LOOP_WINDOW" 2>/dev/null || true
```
Does **not** recycle running worktrees — agents may still be mid-task. Run `capacity.sh` to see what's still in progress.
## tmux send-keys pattern
**Always split long messages into text + Enter as two separate calls with a sleep between them.** If sent as one call (`"text" Enter`), Enter can fire before the full string is buffered into Claude's input — leaving the message stuck as `[Pasted text +N lines]` unsent.
```bash
# CORRECT — text then Enter separately
tmux send-keys -t "$WINDOW" "your long message here"
sleep 0.3
tmux send-keys -t "$WINDOW" Enter
# WRONG — Enter may fire before text is buffered
tmux send-keys -t "$WINDOW" "your long message here" Enter
```
Short single-character sends (`y`, `Down`, empty Enter for dialog approval) are safe to combine since they have no buffering lag.
---
## Protected worktrees
Some worktrees must **never** be used as spare worktrees for agent tasks because they host files critical to the orchestrator itself:
| Worktree | Protected branch | Why |
|---|---|---|
| `AutoGPT1` | `dx/orchestrate-skill` | Hosts the orchestrate skill scripts. `recycle-agent.sh` would check out `spare/1`, wiping `.claude/skills/` and breaking all subsequent `spawn-agent.sh` calls. |
**Rule**: when selecting spare worktrees via `find-spare.sh`, skip any worktree whose CURRENT branch matches a protected branch. If you accidentally spawn an agent in a protected worktree, do not let `recycle-agent.sh` run on it — manually restore the branch after the agent finishes.
When `dx/orchestrate-skill` is merged into `dev`, `AutoGPT1` becomes a normal spare again.
---
## Key rules
1. **Scripts do all the heavy lifting** — don't reimplement their logic inline in this file
2. **Never ask the user to pick a worktree** — auto-assign from `find-spare.sh` output
3. **Never restart a running agent** — only restart on `idle` kicks (foreground is a shell)
4. **Auto-dismiss settings dialogs** — if "Enter to confirm" appears, send Down+Enter
5. **Always `--permission-mode bypassPermissions`** on every spawn
6. **Escalate after 3 kicks** — mark `escalated`, surface to user
7. **Atomic state writes** — always write to `.tmp` then `mv`
8. **Never approve destructive commands** outside the worktree scope — when in doubt, escalate
9. **Never recycle without verification** — `verify-complete.sh` must pass before recycling
10. **No TASK.md files** — commit risk; use state file + `gh pr view` for agent context persistence
11. **Re-brief stalled agents** — read objective from state file + `gh pr view`, send via tmux
12. **ORCHESTRATOR:DONE is a signal to verify, not to accept** — always run `verify-complete.sh` and check CI run timestamp before recycling
13. **Protected worktrees** — never use the worktree hosting the skill scripts as a spare
14. **Images via file path** — save screenshots to `/tmp/orchestrator-context-<ts>.png`, pass path in objective; agents read with the `Read` tool
15. **Split send-keys** — always separate text and Enter with `sleep 0.3` between calls for long strings

View File

@@ -1,43 +0,0 @@
#!/usr/bin/env bash
# capacity.sh — show fleet capacity: available spare worktrees + in-use agents
#
# Usage: capacity.sh [REPO_ROOT]
# REPO_ROOT defaults to the root worktree of the current git repo.
#
# Reads: ~/.claude/orchestrator-state.json (skipped if missing or corrupt)
set -euo pipefail
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
echo "=== Available (spare) worktrees ==="
if [ -n "$REPO_ROOT" ]; then
SPARE=$("$SCRIPTS_DIR/find-spare.sh" "$REPO_ROOT" 2>/dev/null || echo "")
else
SPARE=$("$SCRIPTS_DIR/find-spare.sh" 2>/dev/null || echo "")
fi
if [ -z "$SPARE" ]; then
echo " (none)"
else
while IFS= read -r line; do
[ -z "$line" ] && continue
echo "$line"
done <<< "$SPARE"
fi
echo ""
echo "=== In-use worktrees ==="
if [ -f "$STATE_FILE" ] && jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
IN_USE=$(jq -r '.agents[] | select(.state != "done") | " [\(.state)] \(.worktree_path) → \(.branch)"' \
"$STATE_FILE" 2>/dev/null || echo "")
if [ -n "$IN_USE" ]; then
echo "$IN_USE"
else
echo " (none)"
fi
else
echo " (no active state file)"
fi

View File

@@ -1,85 +0,0 @@
#!/usr/bin/env bash
# classify-pane.sh — Classify the current state of a tmux pane
#
# Usage: classify-pane.sh <tmux-target>
# tmux-target: e.g. "work:0", "work:1.0"
#
# Output (stdout): JSON object:
# { "state": "running|idle|waiting_approval|complete", "reason": "...", "pane_cmd": "..." }
#
# Exit codes: 0=ok, 1=error (invalid target or tmux window not found)
set -euo pipefail
TARGET="${1:-}"
if [ -z "$TARGET" ]; then
echo '{"state":"error","reason":"no target provided","pane_cmd":""}'
exit 1
fi
# Validate tmux target format: session:window or session:window.pane
if ! [[ "$TARGET" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
echo '{"state":"error","reason":"invalid tmux target format","pane_cmd":""}'
exit 1
fi
# Check session exists (use %%:* to extract session name from session:window)
if ! tmux list-windows -t "${TARGET%%:*}" &>/dev/null 2>&1; then
echo '{"state":"error","reason":"tmux target not found","pane_cmd":""}'
exit 1
fi
# Get the current foreground command in the pane
PANE_CMD=$(tmux display-message -t "$TARGET" -p '#{pane_current_command}' 2>/dev/null || echo "unknown")
# Capture and strip ANSI codes (use perl for cross-platform compatibility — BSD sed lacks \x1b support)
RAW=$(tmux capture-pane -t "$TARGET" -p -S -50 2>/dev/null || echo "")
CLEAN=$(echo "$RAW" | perl -pe 's/\x1b\[[0-9;]*[a-zA-Z]//g; s/\x1b\(B//g; s/\x1b\[\?[0-9]*[hl]//g; s/\r//g' \
| grep -v '^[[:space:]]*$' || true)
# --- Check: explicit completion marker ---
# Must be on its own line (not buried in the objective text sent at spawn time).
if echo "$CLEAN" | grep -qE "^[[:space:]]*ORCHESTRATOR:DONE[[:space:]]*$"; then
jq -n --arg cmd "$PANE_CMD" '{"state":"complete","reason":"ORCHESTRATOR:DONE marker found","pane_cmd":$cmd}'
exit 0
fi
# --- Check: Claude Code approval prompt patterns ---
LAST_40=$(echo "$CLEAN" | tail -40)
APPROVAL_PATTERNS=(
"Do you want to proceed"
"Do you want to make this"
"\\[y/n\\]"
"\\[Y/n\\]"
"\\[n/Y\\]"
"Proceed\\?"
"Allow this command"
"Run bash command"
"Allow bash"
"Would you like"
"Press enter to continue"
"Esc to cancel"
)
for pattern in "${APPROVAL_PATTERNS[@]}"; do
if echo "$LAST_40" | grep -qiE "$pattern"; then
jq -n --arg pattern "$pattern" --arg cmd "$PANE_CMD" \
'{"state":"waiting_approval","reason":"approval pattern: \($pattern)","pane_cmd":$cmd}'
exit 0
fi
done
# --- Check: shell prompt (claude has exited) ---
# If the foreground process is a shell (not claude/node), the agent has exited
case "$PANE_CMD" in
zsh|bash|fish|sh|dash|tcsh|ksh)
jq -n --arg cmd "$PANE_CMD" \
'{"state":"idle","reason":"agent exited — shell prompt active","pane_cmd":$cmd}'
exit 0
;;
esac
# Agent is still running (claude/node/python is the foreground process)
jq -n --arg cmd "$PANE_CMD" \
'{"state":"running","reason":"foreground process: \($cmd)","pane_cmd":$cmd}'
exit 0

View File

@@ -1,24 +0,0 @@
#!/usr/bin/env bash
# find-spare.sh — list worktrees on spare/N branches (free to use)
#
# Usage: find-spare.sh [REPO_ROOT]
# REPO_ROOT defaults to the root worktree containing the current git repo.
#
# Output (stdout): one line per available worktree: "PATH BRANCH"
# e.g.: /Users/me/Code/AutoGPT3 spare/3
set -euo pipefail
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
if [ -z "$REPO_ROOT" ]; then
echo "Error: not inside a git repo and no REPO_ROOT provided" >&2
exit 1
fi
git -C "$REPO_ROOT" worktree list --porcelain \
| awk '
/^worktree / { path = substr($0, 10) }
/^branch / { branch = substr($0, 8); print path " " branch }
' \
| { grep -E " refs/heads/spare/[0-9]+$" || true; } \
| sed 's|refs/heads/||'

View File

@@ -1,40 +0,0 @@
#!/usr/bin/env bash
# notify.sh — send a fleet notification message
#
# Delivery order (first available wins):
# 1. Discord webhook — DISCORD_WEBHOOK_URL env var OR state file .discord_webhook
# 2. macOS notification center — osascript (silent fail if unavailable)
# 3. Stdout only
#
# Usage: notify.sh MESSAGE
# Exit: always 0 (notification failure must not abort the caller)
MESSAGE="${1:-}"
[ -z "$MESSAGE" ] && exit 0
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
# --- Resolve Discord webhook ---
WEBHOOK="${DISCORD_WEBHOOK_URL:-}"
if [ -z "$WEBHOOK" ] && [ -f "$STATE_FILE" ]; then
WEBHOOK=$(jq -r '.discord_webhook // ""' "$STATE_FILE" 2>/dev/null || echo "")
fi
# --- Discord delivery ---
if [ -n "$WEBHOOK" ]; then
PAYLOAD=$(jq -n --arg msg "$MESSAGE" '{"content": $msg}')
curl -s -X POST "$WEBHOOK" \
-H "Content-Type: application/json" \
-d "$PAYLOAD" > /dev/null 2>&1 || true
fi
# --- macOS notification center (silent if not macOS or osascript missing) ---
if command -v osascript &>/dev/null 2>&1; then
# Escape single quotes for AppleScript
SAFE_MSG=$(echo "$MESSAGE" | sed "s/'/\\\\'/g")
osascript -e "display notification \"${SAFE_MSG}\" with title \"Orchestrator\"" 2>/dev/null || true
fi
# Always print to stdout so run-loop.sh logs it
echo "$MESSAGE"
exit 0

View File

@@ -1,257 +0,0 @@
#!/usr/bin/env bash
# poll-cycle.sh — Single orchestrator poll cycle
#
# Reads ~/.claude/orchestrator-state.json, classifies each agent, updates state,
# and outputs a JSON array of actions for Claude to take.
#
# Usage: poll-cycle.sh
# Output (stdout): JSON array of action objects
# [{ "window": "work:0", "action": "kick|approve|none", "state": "...",
# "worktree": "...", "objective": "...", "reason": "..." }]
#
# The state file is updated in-place (atomic write via .tmp).
set -euo pipefail
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
CLASSIFY="$SCRIPTS_DIR/classify-pane.sh"
# Cross-platform md5: always outputs just the hex digest
md5_hash() {
if command -v md5sum &>/dev/null; then
md5sum | awk '{print $1}'
else
md5 | awk '{print $NF}'
fi
}
# Clean up temp file on any exit (avoids stale .tmp if jq write fails)
trap 'rm -f "${STATE_FILE}.tmp"' EXIT
# Ensure state file exists
if [ ! -f "$STATE_FILE" ]; then
echo '{"active":false,"agents":[]}' > "$STATE_FILE"
fi
# Validate JSON upfront before any jq reads that run under set -e.
# A truncated/corrupt file (e.g. from a SIGKILL mid-write) would otherwise
# abort the script at the ACTIVE read below without emitting any JSON output.
if ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
echo "State file parse error — check $STATE_FILE" >&2
echo "[]"
exit 0
fi
ACTIVE=$(jq -r '.active // false' "$STATE_FILE")
if [ "$ACTIVE" != "true" ]; then
echo "[]"
exit 0
fi
NOW=$(date +%s)
IDLE_THRESHOLD=$(jq -r '.idle_threshold_seconds // 300' "$STATE_FILE")
ACTIONS="[]"
UPDATED_AGENTS="[]"
# Read agents as newline-delimited JSON objects.
# jq exits non-zero when .agents[] has no matches on an empty array, which is valid —
# so we suppress that exit code and separately validate the file is well-formed JSON.
if ! AGENTS_JSON=$(jq -e -c '.agents // empty | .[]' "$STATE_FILE" 2>/dev/null); then
if ! jq -e '.' "$STATE_FILE" > /dev/null 2>&1; then
echo "State file parse error — check $STATE_FILE" >&2
fi
echo "[]"
exit 0
fi
if [ -z "$AGENTS_JSON" ]; then
echo "[]"
exit 0
fi
while IFS= read -r agent; do
[ -z "$agent" ] && continue
# Use // "" defaults so a single malformed field doesn't abort the whole cycle
WINDOW=$(echo "$agent" | jq -r '.window // ""')
WORKTREE=$(echo "$agent" | jq -r '.worktree // ""')
OBJECTIVE=$(echo "$agent"| jq -r '.objective // ""')
STATE=$(echo "$agent" | jq -r '.state // "running"')
LAST_HASH=$(echo "$agent"| jq -r '.last_output_hash // ""')
IDLE_SINCE=$(echo "$agent"| jq -r '.idle_since // 0')
REVISION_COUNT=$(echo "$agent"| jq -r '.revision_count // 0')
# Validate window format to prevent tmux target injection.
# Allow session:window (numeric or named) and session:window.pane
if ! [[ "$WINDOW" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
echo "Skipping agent with invalid window value: $WINDOW" >&2
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
continue
fi
# Pass-through terminal-state agents
if [[ "$STATE" == "done" || "$STATE" == "escalated" || "$STATE" == "complete" || "$STATE" == "pending_evaluation" ]]; then
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
continue
fi
# Classify pane.
# classify-pane.sh always emits JSON before exit (even on error), so using
# "|| echo '...'" would concatenate two JSON objects when it exits non-zero.
# Use "|| true" inside the substitution so set -euo pipefail does not abort
# the poll cycle when classify exits with a non-zero status code.
CLASSIFICATION=$("$CLASSIFY" "$WINDOW" 2>/dev/null || true)
[ -z "$CLASSIFICATION" ] && CLASSIFICATION='{"state":"error","reason":"classify failed","pane_cmd":"unknown"}'
PANE_STATE=$(echo "$CLASSIFICATION" | jq -r '.state')
PANE_REASON=$(echo "$CLASSIFICATION" | jq -r '.reason')
# Capture full pane output once — used for hash (stuck detection) and checkpoint parsing.
# Use -S -500 to get the last ~500 lines of scrollback so checkpoints aren't missed.
RAW=$(tmux capture-pane -t "$WINDOW" -p -S -500 2>/dev/null || echo "")
# --- Checkpoint tracking ---
# Parse any "CHECKPOINT:<step>" lines the agent has output and merge into state file.
# The agent writes these as it completes each required step so verify-complete.sh can gate recycling.
EXISTING_CPS=$(echo "$agent" | jq -c '.checkpoints // []')
NEW_CHECKPOINTS_JSON="$EXISTING_CPS"
if [ -n "$RAW" ]; then
FOUND_CPS=$(echo "$RAW" \
| grep -oE "CHECKPOINT:[a-zA-Z0-9_-]+" \
| sed 's/CHECKPOINT://' \
| sort -u \
| jq -R . | jq -s . 2>/dev/null || echo "[]")
NEW_CHECKPOINTS_JSON=$(jq -n \
--argjson existing "$EXISTING_CPS" \
--argjson found "$FOUND_CPS" \
'($existing + $found) | unique' 2>/dev/null || echo "$EXISTING_CPS")
fi
# Compute content hash for stuck-detection (only for running agents)
CURRENT_HASH=""
if [[ "$PANE_STATE" == "running" ]] && [ -n "$RAW" ]; then
CURRENT_HASH=$(echo "$RAW" | tail -20 | md5_hash)
fi
NEW_STATE="$STATE"
NEW_IDLE_SINCE="$IDLE_SINCE"
NEW_REVISION_COUNT="$REVISION_COUNT"
ACTION="none"
REASON="$PANE_REASON"
case "$PANE_STATE" in
complete)
# Agent output ORCHESTRATOR:DONE — mark pending_evaluation so orchestrator handles it.
# run-loop does NOT verify or notify; orchestrator's background poll picks this up.
NEW_STATE="pending_evaluation"
ACTION="complete" # run-loop logs it but takes no action
;;
waiting_approval)
NEW_STATE="waiting_approval"
ACTION="approve"
;;
idle)
# Agent process has exited — needs restart
NEW_STATE="idle"
ACTION="kick"
REASON="agent exited (shell is foreground)"
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
NEW_IDLE_SINCE=$NOW
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
NEW_STATE="escalated"
ACTION="none"
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
fi
;;
running)
# Clear idle_since only when transitioning from idle (agent was kicked and
# restarted). Do NOT reset for stuck — idle_since must persist across polls
# so STUCK_DURATION can accumulate and trigger escalation.
# Also update the local IDLE_SINCE so the hash-stability check below uses
# the reset value on this same poll, not the stale kick timestamp.
if [[ "$STATE" == "idle" ]]; then
NEW_IDLE_SINCE=0
IDLE_SINCE=0
fi
# Check if hash has been stable (agent may be stuck mid-task)
if [ -n "$CURRENT_HASH" ] && [ "$CURRENT_HASH" = "$LAST_HASH" ] && [ "$LAST_HASH" != "" ]; then
if [ "$IDLE_SINCE" = "0" ] || [ "$IDLE_SINCE" = "null" ]; then
NEW_IDLE_SINCE=$NOW
else
STUCK_DURATION=$(( NOW - IDLE_SINCE ))
if [ "$STUCK_DURATION" -gt "$IDLE_THRESHOLD" ]; then
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
NEW_IDLE_SINCE=$NOW
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
NEW_STATE="escalated"
ACTION="none"
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
else
NEW_STATE="stuck"
ACTION="kick"
REASON="output unchanged for ${STUCK_DURATION}s (threshold: ${IDLE_THRESHOLD}s)"
fi
fi
fi
else
# Only reset the idle timer when we have a valid hash comparison (pane
# capture succeeded). If CURRENT_HASH is empty (tmux capture-pane failed),
# preserve existing timers so stuck detection is not inadvertently reset.
if [ -n "$CURRENT_HASH" ]; then
NEW_STATE="running"
NEW_IDLE_SINCE=0
fi
fi
;;
error)
REASON="classify error: $PANE_REASON"
;;
esac
# Build updated agent record (ensure idle_since and revision_count are numeric)
# Use || true on each jq call so a malformed field skips this agent rather than
# aborting the entire poll cycle under set -e.
UPDATED_AGENT=$(echo "$agent" | jq \
--arg state "$NEW_STATE" \
--arg hash "$CURRENT_HASH" \
--argjson now "$NOW" \
--arg idle_since "$NEW_IDLE_SINCE" \
--arg revision_count "$NEW_REVISION_COUNT" \
--argjson checkpoints "$NEW_CHECKPOINTS_JSON" \
'.state = $state
| .last_output_hash = (if $hash == "" then .last_output_hash else $hash end)
| .last_seen_at = $now
| .idle_since = ($idle_since | tonumber)
| .revision_count = ($revision_count | tonumber)
| .checkpoints = $checkpoints' 2>/dev/null) || {
echo "Warning: failed to build updated agent for window $WINDOW — keeping original" >&2
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
continue
}
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$UPDATED_AGENT" '. + [$a]')
# Add action if needed
if [ "$ACTION" != "none" ]; then
ACTION_OBJ=$(jq -n \
--arg window "$WINDOW" \
--arg action "$ACTION" \
--arg state "$NEW_STATE" \
--arg worktree "$WORKTREE" \
--arg objective "$OBJECTIVE" \
--arg reason "$REASON" \
'{window:$window, action:$action, state:$state, worktree:$worktree, objective:$objective, reason:$reason}')
ACTIONS=$(echo "$ACTIONS" | jq --argjson a "$ACTION_OBJ" '. + [$a]')
fi
done <<< "$AGENTS_JSON"
# Atomic state file update
jq --argjson agents "$UPDATED_AGENTS" \
--argjson now "$NOW" \
'.agents = $agents | .last_poll_at = $now' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
echo "$ACTIONS"

View File

@@ -1,32 +0,0 @@
#!/usr/bin/env bash
# recycle-agent.sh — kill a tmux window and restore the worktree to its spare branch
#
# Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH
# WINDOW — tmux target, e.g. autogpt1:3
# WORKTREE_PATH — absolute path to the git worktree
# SPARE_BRANCH — branch to restore, e.g. spare/6
#
# Stdout: one status line
set -euo pipefail
if [ $# -lt 3 ]; then
echo "Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH" >&2
exit 1
fi
WINDOW="$1"
WORKTREE_PATH="$2"
SPARE_BRANCH="$3"
# Kill the tmux window (ignore error — may already be gone)
tmux kill-window -t "$WINDOW" 2>/dev/null || true
# Restore to spare branch: abort any in-progress operation, then clean
git -C "$WORKTREE_PATH" rebase --abort 2>/dev/null || true
git -C "$WORKTREE_PATH" merge --abort 2>/dev/null || true
git -C "$WORKTREE_PATH" reset --hard HEAD 2>/dev/null
git -C "$WORKTREE_PATH" clean -fd 2>/dev/null
git -C "$WORKTREE_PATH" checkout "$SPARE_BRANCH"
echo "Recycled: $(basename "$WORKTREE_PATH")$SPARE_BRANCH (window $WINDOW closed)"

View File

@@ -1,164 +0,0 @@
#!/usr/bin/env bash
# run-loop.sh — Mechanical babysitter for the agent fleet (runs in its own tmux window)
#
# Handles ONLY two things that need no intelligence:
# idle → restart claude using --resume SESSION_ID (or --continue) to restore context
# approve → auto-approve safe dialogs, press Enter on numbered-option dialogs
#
# Everything else — ORCHESTRATOR:DONE, verification, /pr-test, final evaluation,
# marking done, deciding to close windows — is the orchestrating Claude's job.
# poll-cycle.sh sets state to pending_evaluation when ORCHESTRATOR:DONE is detected;
# the orchestrator's background poll loop handles it from there.
#
# Usage: run-loop.sh
# Env: POLL_INTERVAL (default: 30), ORCHESTRATOR_STATE_FILE
set -euo pipefail
# Copy scripts to a stable location outside the repo so they survive branch
# checkouts (e.g. recycle-agent.sh switching spare/N back into this worktree
# would wipe .claude/skills/orchestrate/scripts if the skill only exists on the
# current branch).
_ORIGIN_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
STABLE_SCRIPTS_DIR="$HOME/.claude/orchestrator/scripts"
mkdir -p "$STABLE_SCRIPTS_DIR"
cp "$_ORIGIN_DIR"/*.sh "$STABLE_SCRIPTS_DIR/"
chmod +x "$STABLE_SCRIPTS_DIR"/*.sh
SCRIPTS_DIR="$STABLE_SCRIPTS_DIR"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
POLL_INTERVAL="${POLL_INTERVAL:-30}"
# ---------------------------------------------------------------------------
# update_state WINDOW FIELD VALUE
# ---------------------------------------------------------------------------
update_state() {
local window="$1" field="$2" value="$3"
jq --arg w "$window" --arg f "$field" --arg v "$value" \
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
}
update_state_int() {
local window="$1" field="$2" value="$3"
jq --arg w "$window" --arg f "$field" --argjson v "$value" \
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
}
agent_field() {
jq -r --arg w "$1" --arg f "$2" \
'.agents[] | select(.window == $w) | .[$f] // ""' \
"$STATE_FILE" 2>/dev/null
}
# ---------------------------------------------------------------------------
# wait_for_prompt WINDOW — wait up to 60s for Claude's prompt
# ---------------------------------------------------------------------------
wait_for_prompt() {
local window="$1"
for i in $(seq 1 60); do
local cmd pane
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
if echo "$pane" | grep -q "Enter to confirm"; then
tmux send-keys -t "$window" Down Enter; sleep 2; continue
fi
[[ "$cmd" == "node" ]] && echo "$pane" | grep -q "" && return 0
sleep 1
done
return 1 # timed out
}
# ---------------------------------------------------------------------------
# handle_kick WINDOW STATE — only for idle (crashed) agents, not stuck
# ---------------------------------------------------------------------------
handle_kick() {
local window="$1" state="$2"
[[ "$state" != "idle" ]] && return # stuck agents handled by supervisor
local worktree_path session_id
worktree_path=$(agent_field "$window" "worktree_path")
session_id=$(agent_field "$window" "session_id")
echo "[$(date +%H:%M:%S)] KICK restart $window — agent exited, resuming session"
# Resume the exact session so the agent retains full context — no need to re-send objective
if [ -n "$session_id" ]; then
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --resume '${session_id}' --permission-mode bypassPermissions" Enter
else
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --continue --permission-mode bypassPermissions" Enter
fi
wait_for_prompt "$window" || echo "[$(date +%H:%M:%S)] KICK WARNING $window — timed out waiting for "
}
# ---------------------------------------------------------------------------
# handle_approve WINDOW — auto-approve dialogs that need no judgment
# ---------------------------------------------------------------------------
handle_approve() {
local window="$1"
local pane_tail
pane_tail=$(tmux capture-pane -t "$window" -p 2>/dev/null | tail -3 || echo "")
# Settings error dialog at startup
if echo "$pane_tail" | grep -q "Enter to confirm"; then
echo "[$(date +%H:%M:%S)] APPROVE dialog $window — settings error"
tmux send-keys -t "$window" Down Enter
return
fi
# Numbered-option dialog (e.g. "Do you want to make this edit?")
# is already on option 1 (Yes) — Enter confirms it
if echo "$pane_tail" | grep -qE "\s*1\." || echo "$pane_tail" | grep -q "Esc to cancel"; then
echo "[$(date +%H:%M:%S)] APPROVE edit $window"
tmux send-keys -t "$window" "" Enter
return
fi
# y/n prompt for safe operations
if echo "$pane_tail" | grep -qiE "(^git |^npm |^pnpm |^poetry |^pytest|^docker |^make |^cargo |^pip |^yarn |curl .*(localhost|127\.0\.0\.1))"; then
echo "[$(date +%H:%M:%S)] APPROVE safe $window"
tmux send-keys -t "$window" "y" Enter
return
fi
# Anything else — supervisor handles it, just log
echo "[$(date +%H:%M:%S)] APPROVE skip $window — unknown dialog, supervisor will handle"
}
# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------
echo "[$(date +%H:%M:%S)] run-loop started (mechanical only, poll every ${POLL_INTERVAL}s)"
echo "[$(date +%H:%M:%S)] Supervisor: orchestrating Claude session (not a separate window)"
echo "---"
while true; do
if ! jq -e '.active == true' "$STATE_FILE" >/dev/null 2>&1; then
echo "[$(date +%H:%M:%S)] active=false — exiting."
exit 0
fi
ACTIONS=$("$SCRIPTS_DIR/poll-cycle.sh" 2>/dev/null || echo "[]")
KICKED=0; DONE=0
while IFS= read -r action; do
[ -z "$action" ] && continue
WINDOW=$(echo "$action" | jq -r '.window // ""')
ACTION=$(echo "$action" | jq -r '.action // ""')
STATE=$(echo "$action" | jq -r '.state // ""')
case "$ACTION" in
kick) handle_kick "$WINDOW" "$STATE" || true; KICKED=$(( KICKED + 1 )) ;;
approve) handle_approve "$WINDOW" || true ;;
complete) DONE=$(( DONE + 1 )) ;; # poll-cycle already set state=pending_evaluation; orchestrator handles
esac
done < <(echo "$ACTIONS" | jq -c '.[]' 2>/dev/null || true)
RUNNING=$(jq '[.agents[] | select(.state | test("running|stuck|waiting_approval|idle"))] | length' \
"$STATE_FILE" 2>/dev/null || echo 0)
echo "[$(date +%H:%M:%S)] Poll — ${RUNNING} running ${KICKED} kicked ${DONE} recycled"
sleep "$POLL_INTERVAL"
done

View File

@@ -1,122 +0,0 @@
#!/usr/bin/env bash
# spawn-agent.sh — create tmux window, checkout branch, launch claude, send task
#
# Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]
# SESSION — tmux session name, e.g. autogpt1
# WORKTREE_PATH — absolute path to the git worktree
# SPARE_BRANCH — spare branch being replaced, e.g. spare/6 (saved for recycle)
# NEW_BRANCH — task branch to create, e.g. feat/my-feature
# OBJECTIVE — task description sent to the agent
# PR_NUMBER — (optional) GitHub PR number for completion verification
# STEPS... — (optional) required checkpoint names, e.g. pr-address pr-test
#
# Stdout: SESSION:WINDOW_INDEX (nothing else — callers rely on this)
# Exit non-zero on failure.
set -euo pipefail
if [ $# -lt 5 ]; then
echo "Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]" >&2
exit 1
fi
SESSION="$1"
WORKTREE_PATH="$2"
SPARE_BRANCH="$3"
NEW_BRANCH="$4"
OBJECTIVE="$5"
PR_NUMBER="${6:-}"
STEPS=("${@:7}")
WORKTREE_NAME=$(basename "$WORKTREE_PATH")
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
# Generate a stable session ID so this agent's Claude session can always be resumed:
# claude --resume $SESSION_ID --permission-mode bypassPermissions
SESSION_ID=$(uuidgen 2>/dev/null || python3 -c "import uuid; print(uuid.uuid4())")
# Create (or switch to) the task branch
git -C "$WORKTREE_PATH" checkout -b "$NEW_BRANCH" 2>/dev/null \
|| git -C "$WORKTREE_PATH" checkout "$NEW_BRANCH"
# Open a new named tmux window; capture its numeric index
WIN_IDX=$(tmux new-window -t "$SESSION" -n "$WORKTREE_NAME" -P -F '#{window_index}')
WINDOW="${SESSION}:${WIN_IDX}"
# Append the initial agent record to the state file so subsequent jq updates find it.
# This must happen before the pr_number/steps update below.
if [ -f "$STATE_FILE" ]; then
NOW=$(date +%s)
jq --arg window "$WINDOW" \
--arg worktree "$WORKTREE_NAME" \
--arg worktree_path "$WORKTREE_PATH" \
--arg spare_branch "$SPARE_BRANCH" \
--arg branch "$NEW_BRANCH" \
--arg objective "$OBJECTIVE" \
--arg session_id "$SESSION_ID" \
--argjson now "$NOW" \
'.agents += [{
"window": $window,
"worktree": $worktree,
"worktree_path": $worktree_path,
"spare_branch": $spare_branch,
"branch": $branch,
"objective": $objective,
"session_id": $session_id,
"state": "running",
"checkpoints": [],
"last_output_hash": "",
"last_seen_at": $now,
"spawned_at": $now,
"idle_since": 0,
"revision_count": 0,
"last_rebriefed_at": 0
}]' "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
fi
# Store pr_number + steps in state file if provided (enables verify-complete.sh).
# The agent record was appended above so the jq select now finds it.
if [ -n "$PR_NUMBER" ] && [ -f "$STATE_FILE" ]; then
if [ "${#STEPS[@]}" -gt 0 ]; then
STEPS_JSON=$(printf '%s\n' "${STEPS[@]}" | jq -R . | jq -s .)
else
STEPS_JSON='[]'
fi
jq --arg w "$WINDOW" --arg pr "$PR_NUMBER" --argjson steps "$STEPS_JSON" \
'.agents |= map(if .window == $w then . + {pr_number: $pr, steps: $steps, checkpoints: []} else . end)' \
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
fi
# Launch claude with a stable session ID so it can always be resumed after a crash:
# claude --resume SESSION_ID --permission-mode bypassPermissions
tmux send-keys -t "$WINDOW" "cd '${WORKTREE_PATH}' && claude --permission-mode bypassPermissions --session-id '${SESSION_ID}'" Enter
# Wait up to 60s for claude to be fully interactive:
# both pane_current_command == 'node' AND the '' prompt is visible.
PROMPT_FOUND=false
for i in $(seq 1 60); do
CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "")
PANE=$(tmux capture-pane -t "$WINDOW" -p 2>/dev/null || echo "")
if echo "$PANE" | grep -q "Enter to confirm"; then
tmux send-keys -t "$WINDOW" Down Enter
sleep 2
continue
fi
if [[ "$CMD" == "node" ]] && echo "$PANE" | grep -q ""; then
PROMPT_FOUND=true
break
fi
sleep 1
done
if ! $PROMPT_FOUND; then
echo "[spawn-agent] WARNING: timed out waiting for prompt on $WINDOW — sending objective anyway" >&2
fi
# Send the task. Split text and Enter — if combined, Enter can fire before the string
# is fully buffered, leaving the message stuck as "[Pasted text +N lines]" unsent.
tmux send-keys -t "$WINDOW" "${OBJECTIVE} Output each completed step as CHECKPOINT:<step-name>. When ALL steps are done, output ORCHESTRATOR:DONE on its own line."
sleep 0.3
tmux send-keys -t "$WINDOW" Enter
# Only output the window address — nothing else (callers parse this)
echo "$WINDOW"

View File

@@ -1,43 +0,0 @@
#!/usr/bin/env bash
# status.sh — print orchestrator status: state file summary + live tmux pane commands
#
# Usage: status.sh
# Reads: ~/.claude/orchestrator-state.json
set -euo pipefail
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
if [ ! -f "$STATE_FILE" ] || ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
echo "No orchestrator state found at $STATE_FILE"
exit 0
fi
# Header: active status, session, thresholds, last poll
jq -r '
"=== Orchestrator [\(if .active then "RUNNING" else "STOPPED" end)] ===",
"Session: \(.tmux_session // "unknown") | Idle threshold: \(.idle_threshold_seconds // 300)s",
"Last poll: \(if (.last_poll_at // 0) == 0 then "never" else (.last_poll_at | strftime("%H:%M:%S")) end)",
""
' "$STATE_FILE"
# Each agent: state, window, worktree/branch, truncated objective
AGENT_COUNT=$(jq '.agents | length' "$STATE_FILE")
if [ "$AGENT_COUNT" -eq 0 ]; then
echo " (no agents registered)"
else
jq -r '
.agents[] |
" [\(.state | ascii_upcase)] \(.window) \(.worktree)/\(.branch)",
" \(.objective // "" | .[0:70])"
' "$STATE_FILE"
fi
echo ""
# Live pane_current_command for non-done agents
while IFS= read -r WINDOW; do
[ -z "$WINDOW" ] && continue
CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "unreachable")
echo " $WINDOW live: $CMD"
done < <(jq -r '.agents[] | select(.state != "done") | .window' "$STATE_FILE" 2>/dev/null || true)

View File

@@ -1,180 +0,0 @@
#!/usr/bin/env bash
# verify-complete.sh — verify a PR task is truly done before marking the agent done
#
# Check order matters:
# 1. Checkpoints — did the agent do all required steps?
# 2. CI complete — no pending (bots post comments AFTER their check runs, must wait)
# 3. CI passing — no failures (agent must fix before done)
# 4. spawned_at — a new CI run was triggered after agent spawned (proves real work)
# 5. Unresolved threads — checked AFTER CI so bot-posted comments are included
# 6. CHANGES_REQUESTED — checked AFTER CI so bot reviews are included
#
# Usage: verify-complete.sh WINDOW
# Exit 0 = verified complete; exit 1 = not complete (stderr has reason)
set -euo pipefail
WINDOW="$1"
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
PR_NUMBER=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .pr_number // ""' "$STATE_FILE" 2>/dev/null)
STEPS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .steps // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
CHECKPOINTS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .checkpoints // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
WORKTREE_PATH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .worktree_path // ""' "$STATE_FILE" 2>/dev/null)
BRANCH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .branch // ""' "$STATE_FILE" 2>/dev/null)
SPAWNED_AT=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .spawned_at // "0"' "$STATE_FILE" 2>/dev/null || echo "0")
# No PR number = cannot verify
if [ -z "$PR_NUMBER" ]; then
echo "NOT COMPLETE: no pr_number in state — set pr_number or mark done manually" >&2
exit 1
fi
# --- Check 1: all required steps are checkpointed ---
MISSING=""
while IFS= read -r step; do
[ -z "$step" ] && continue
if ! echo "$CHECKPOINTS" | grep -qFx "$step"; then
MISSING="$MISSING $step"
fi
done <<< "$STEPS"
if [ -n "$MISSING" ]; then
echo "NOT COMPLETE: missing checkpoints:$MISSING on PR #$PR_NUMBER" >&2
exit 1
fi
# Resolve repo for all GitHub checks below
REPO=$(jq -r '.repo // ""' "$STATE_FILE" 2>/dev/null || echo "")
if [ -z "$REPO" ] && [ -n "$WORKTREE_PATH" ] && [ -d "$WORKTREE_PATH" ]; then
REPO=$(git -C "$WORKTREE_PATH" remote get-url origin 2>/dev/null \
| sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
fi
if [ -z "$REPO" ]; then
echo "Warning: cannot resolve repo — skipping CI/thread checks" >&2
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓ (CI/thread checks skipped — no repo)"
exit 0
fi
CI_BUCKETS=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket 2>/dev/null || echo "[]")
# --- Check 2: CI fully complete — no pending checks ---
# Pending checks MUST finish before we check threads/reviews:
# bots (Seer, Check PR Status, etc.) post comments and CHANGES_REQUESTED AFTER their CI check runs.
PENDING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "pending")] | length' 2>/dev/null || echo "0")
if [ "$PENDING" -gt 0 ]; then
PENDING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
| jq -r '[.[] | select(.bucket == "pending") | .name] | join(", ")' 2>/dev/null || echo "unknown")
echo "NOT COMPLETE: $PENDING CI checks still pending on PR #$PR_NUMBER ($PENDING_NAMES)" >&2
exit 1
fi
# --- Check 3: CI passing — no failures ---
FAILING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "fail")] | length' 2>/dev/null || echo "0")
if [ "$FAILING" -gt 0 ]; then
FAILING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
| jq -r '[.[] | select(.bucket == "fail") | .name] | join(", ")' 2>/dev/null || echo "unknown")
echo "NOT COMPLETE: $FAILING failing CI checks on PR #$PR_NUMBER ($FAILING_NAMES)" >&2
exit 1
fi
# --- Check 4: a new CI run was triggered AFTER the agent spawned ---
if [ -n "$BRANCH" ] && [ "${SPAWNED_AT:-0}" -gt 0 ]; then
LATEST_RUN_AT=$(gh run list --repo "$REPO" --branch "$BRANCH" \
--json createdAt --limit 1 2>/dev/null | jq -r '.[0].createdAt // ""')
if [ -n "$LATEST_RUN_AT" ]; then
if date --version >/dev/null 2>&1; then
LATEST_RUN_EPOCH=$(date -d "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
else
LATEST_RUN_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
fi
if [ "$LATEST_RUN_EPOCH" -le "$SPAWNED_AT" ]; then
echo "NOT COMPLETE: latest CI run on $BRANCH predates agent spawn — agent may not have pushed yet" >&2
exit 1
fi
fi
fi
OWNER=$(echo "$REPO" | cut -d/ -f1)
REPONAME=$(echo "$REPO" | cut -d/ -f2)
# --- Check 5: no unresolved review threads (checked AFTER CI — bots post after their check) ---
UNRESOLVED=$(gh api graphql -f query="
{ repository(owner: \"${OWNER}\", name: \"${REPONAME}\") {
pullRequest(number: ${PR_NUMBER}) {
reviewThreads(first: 50) { nodes { isResolved } }
}
}
}
" --jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)] | length' 2>/dev/null || echo "0")
if [ "$UNRESOLVED" -gt 0 ]; then
echo "NOT COMPLETE: $UNRESOLVED unresolved review threads on PR #$PR_NUMBER" >&2
exit 1
fi
# --- Check 6: no CHANGES_REQUESTED (checked AFTER CI — bots post reviews after their check) ---
# A CHANGES_REQUESTED review is stale if the latest commit was pushed AFTER the review was submitted.
# Stale reviews (pre-dating the fixing commits) should not block verification.
#
# Fetch commits and latestReviews in a single call and fail closed — if gh fails,
# treat that as NOT COMPLETE rather than silently passing.
# Use latestReviews (not reviews) so each reviewer's latest state is used — superseded
# CHANGES_REQUESTED entries are automatically excluded when the reviewer later approved.
# Note: we intentionally use committedDate (not PR updatedAt) because updatedAt changes on any
# PR activity (bot comments, label changes) which would create false negatives.
PR_REVIEW_METADATA=$(gh pr view "$PR_NUMBER" --repo "$REPO" \
--json commits,latestReviews 2>/dev/null) || {
echo "NOT COMPLETE: unable to fetch PR review metadata for PR #$PR_NUMBER" >&2
exit 1
}
LATEST_COMMIT_DATE=$(jq -r '.commits[-1].committedDate // ""' <<< "$PR_REVIEW_METADATA")
CHANGES_REQUESTED_REVIEWS=$(jq '[.latestReviews[]? | select(.state == "CHANGES_REQUESTED")]' <<< "$PR_REVIEW_METADATA")
BLOCKING_CHANGES_REQUESTED=0
BLOCKING_REQUESTERS=""
if [ -n "$LATEST_COMMIT_DATE" ] && [ "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length)" -gt 0 ]; then
if date --version >/dev/null 2>&1; then
LATEST_COMMIT_EPOCH=$(date -d "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
else
LATEST_COMMIT_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
fi
while IFS= read -r review; do
[ -z "$review" ] && continue
REVIEW_DATE=$(echo "$review" | jq -r '.submittedAt // ""')
REVIEWER=$(echo "$review" | jq -r '.author.login // "unknown"')
if [ -z "$REVIEW_DATE" ]; then
# No submission date — treat as fresh (conservative: blocks verification)
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
else
if date --version >/dev/null 2>&1; then
REVIEW_EPOCH=$(date -d "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
else
REVIEW_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
fi
if [ "$REVIEW_EPOCH" -gt "$LATEST_COMMIT_EPOCH" ]; then
# Review was submitted AFTER latest commit — still fresh, blocks verification
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
fi
# Review submitted BEFORE latest commit — stale, skip
fi
done <<< "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -c '.[]')"
else
# No commit date or no changes_requested — check raw count as fallback
BLOCKING_CHANGES_REQUESTED=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length 2>/dev/null || echo "0")
BLOCKING_REQUESTERS=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -r '[.[].author.login] | join(", ")' 2>/dev/null || echo "unknown")
fi
if [ "$BLOCKING_CHANGES_REQUESTED" -gt 0 ]; then
echo "NOT COMPLETE: CHANGES_REQUESTED (after latest commit) from ${BLOCKING_REQUESTERS} on PR #$PR_NUMBER" >&2
exit 1
fi
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓, CI complete + green, 0 unresolved threads, no CHANGES_REQUESTED"
exit 0

View File

@@ -90,34 +90,10 @@ Address comments **one at a time**: fix → commit → push → inline reply →
2. Commit and push the fix
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
Use a **markdown commit link** so GitHub renders it as a clickable reference. Get the full SHA with `git rev-parse HEAD` after committing:
| Comment type | How to reply |
|---|---|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
## Codecov coverage
Codecov patch target is **80%** on changed lines. Checks are **informational** (not blocking) but should be green.
### Running coverage locally
**Backend** (from `autogpt_platform/backend/`):
```bash
poetry run pytest -s -vv --cov=backend --cov-branch --cov-report term-missing
```
**Frontend** (from `autogpt_platform/frontend/`):
```bash
pnpm vitest run --coverage
```
### When codecov/patch fails
1. Find uncovered files: `git diff --name-only $(gh pr view --json baseRefName --jq '.baseRefName')...HEAD`
2. For each uncovered file — extract inline logic to `helpers.ts`/`helpers.py` and test those (highest ROI). Colocate tests as `*_test.py` (backend) or `__tests__/*.test.ts` (frontend).
3. Run coverage locally to verify, commit, push.
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
## Format and commit

View File

@@ -547,8 +547,6 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"
@@ -586,27 +584,15 @@ TREE_JSON+=']'
# Step 2: Create tree, commit, and branch ref
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
# Resolve parent commit so screenshots are chained, not orphan root commits
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || echo "")
if [ -n "$PARENT_SHA" ]; then
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
-f "parents[]=$PARENT_SHA" \
--jq '.sha')
else
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
--jq '.sha')
fi
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
--jq '.sha')
gh api "repos/${REPO}/git/refs" \
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
-f sha="$COMMIT_SHA" 2>/dev/null \
|| gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
-X PATCH -f sha="$COMMIT_SHA" -F force=true
-X PATCH -f sha="$COMMIT_SHA" -f force=true
```
Then post the comment with **inline images AND explanations for each screenshot**:
@@ -672,15 +658,6 @@ INNEREOF
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
rm -f "$COMMENT_FILE"
# Verify the posted comment contains inline images — exit 1 if none found
# Use separate --paginate + jq pipe: --jq applies per-page, not to the full list
LAST_COMMENT=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" --paginate 2>/dev/null | jq -r '.[-1].body // ""')
if ! echo "$LAST_COMMENT" | grep -q '!\['; then
echo "ERROR: Posted comment contains no inline images (![). Bare directory links are not acceptable." >&2
exit 1
fi
echo "✓ Inline images verified in posted comment"
```
**The PR comment MUST include:**
@@ -690,103 +667,6 @@ echo "✓ Inline images verified in posted comment"
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
## Step 8: Evaluate and post a formal PR review
After the test comment is posted, evaluate whether the run was thorough enough to make a merge decision, then post a formal GitHub review (approve or request changes). **This step is mandatory — every test run MUST end with a formal review decision.**
### Evaluation criteria
Re-read the PR description:
```bash
gh pr view "$PR_NUMBER" --json body --jq '.body' --repo "$REPO"
```
Score the run against each criterion:
| Criterion | Pass condition |
|-----------|---------------|
| **Coverage** | Every feature/change described in the PR has at least one test scenario |
| **All scenarios pass** | No FAIL rows in the results table |
| **Negative tests** | At least one failure-path test per feature (invalid input, unauthorized, edge case) |
| **Before/after evidence** | Every state-changing API call has before/after values logged |
| **Screenshots are meaningful** | Screenshots show the actual state change, not just a loading spinner or blank page |
| **No regressions** | Existing core flows (login, agent create/run) still work |
### Decision logic
```
ALL criteria pass → APPROVE
Any scenario FAIL or missing PR feature → REQUEST_CHANGES (list gaps)
Evidence weak (no before/after, vague shots) → REQUEST_CHANGES (list what's missing)
```
### Post the review
```bash
REVIEW_FILE=$(mktemp)
# Count results
PASS_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "PASS" || true)
FAIL_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "FAIL" || true)
TOTAL=$(( PASS_COUNT + FAIL_COUNT ))
# List any coverage gaps found during evaluation (populate this array as you assess)
# e.g. COVERAGE_GAPS=("PR claims to add X but no test covers it")
COVERAGE_GAPS=()
```
**If APPROVING** — all criteria met, zero failures, full coverage:
```bash
cat > "$REVIEW_FILE" <<REVIEWEOF
## E2E Test Evaluation — APPROVED
**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed.
**Coverage:** All features described in the PR were exercised.
**Evidence:** Before/after API values logged for all state-changing operations; screenshots show meaningful state transitions.
**Negative tests:** Failure paths tested for each feature.
No regressions observed on core flows.
REVIEWEOF
gh pr review "$PR_NUMBER" --repo "$REPO" --approve --body "$(cat "$REVIEW_FILE")"
echo "✅ PR approved"
```
**If REQUESTING CHANGES** — any failure, coverage gap, or missing evidence:
```bash
FAIL_LIST=$(echo "$TEST_RESULTS_TABLE" | grep "FAIL" | awk -F'|' '{print "- Scenario" $2 "failed"}' || true)
cat > "$REVIEW_FILE" <<REVIEWEOF
## E2E Test Evaluation — Changes Requested
**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed, ${FAIL_COUNT} failed.
### Required before merge
${FAIL_LIST}
$(for gap in "${COVERAGE_GAPS[@]}"; do echo "- $gap"; done)
Please fix the above and re-run the E2E tests.
REVIEWEOF
gh pr review "$PR_NUMBER" --repo "$REPO" --request-changes --body "$(cat "$REVIEW_FILE")"
echo "❌ Changes requested"
```
```bash
rm -f "$REVIEW_FILE"
```
**Rules:**
- In `--fix` mode, fix all failures before posting the review — the review reflects the final state after fixes
- Never approve if any scenario failed, even if it seems like a flake — rerun that scenario first
- Never request changes for issues already fixed in this run
## Fix mode (--fix flag)
When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.

View File

@@ -1,224 +0,0 @@
---
name: write-frontend-tests
description: "Analyze the current branch diff against dev, plan integration tests for changed frontend pages/components, and write them. TRIGGER when user asks to write frontend tests, add test coverage, or 'write tests for my changes'."
user-invocable: true
args: "[base branch] — defaults to dev. Optionally pass a specific base branch to diff against."
metadata:
author: autogpt-team
version: "1.0.0"
---
# Write Frontend Tests
Analyze the current branch's frontend changes, plan integration tests, and write them.
## References
Before writing any tests, read the testing rules and conventions:
- `autogpt_platform/frontend/TESTING.md` — testing strategy, file locations, examples
- `autogpt_platform/frontend/src/tests/AGENTS.md` — detailed testing rules, MSW patterns, decision flowchart
- `autogpt_platform/frontend/src/tests/integrations/test-utils.tsx` — custom render with providers
- `autogpt_platform/frontend/src/tests/integrations/vitest.setup.tsx` — MSW server setup
## Step 1: Identify changed frontend files
```bash
BASE_BRANCH="${ARGUMENTS:-dev}"
cd autogpt_platform/frontend
# Get changed frontend files (excluding generated, config, and test files)
git diff "$BASE_BRANCH"...HEAD --name-only -- src/ \
| grep -v '__generated__' \
| grep -v '__tests__' \
| grep -v '\.test\.' \
| grep -v '\.stories\.' \
| grep -v '\.spec\.'
```
Also read the diff to understand what changed:
```bash
git diff "$BASE_BRANCH"...HEAD --stat -- src/
git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
```
## Step 2: Categorize changes and find test targets
For each changed file, determine:
1. **Is it a page?** (`page.tsx`) — these are the primary test targets
2. **Is it a hook?** (`use*.ts`) — test via the page that uses it
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
**Priority order:**
1. Pages with new/changed data fetching or user interactions
2. Components with complex internal logic (modals, forms, wizards)
3. Hooks with non-trivial business logic
4. Pure helper functions
Skip: styling-only changes, type-only changes, config changes.
## Step 3: Check for existing tests
For each test target, check if tests already exist:
```bash
# For a page at src/app/(platform)/library/page.tsx
ls src/app/\(platform\)/library/__tests__/ 2>/dev/null
# For a component at src/app/(platform)/library/components/AgentCard/AgentCard.tsx
ls src/app/\(platform\)/library/components/AgentCard/__tests__/ 2>/dev/null
```
Note which targets have no tests (need new files) vs which have tests that need updating.
## Step 4: Identify API endpoints used
For each test target, find which API hooks are used:
```bash
# Find generated API hook imports in the changed files
grep -rn 'from.*__generated__/endpoints' src/app/\(platform\)/library/
grep -rn 'use[A-Z].*V[12]' src/app/\(platform\)/library/
```
For each API hook found, locate the corresponding MSW handler:
```bash
# If the page uses useGetV2ListLibraryAgents, find its MSW handlers
grep -rn 'getGetV2ListLibraryAgents.*Handler' src/app/api/__generated__/endpoints/library/library.msw.ts
```
List every MSW handler you will need (200 for happy path, 4xx for error paths).
## Step 5: Write the test plan
Before writing code, output a plan as a numbered list:
```
Test plan for [branch name]:
1. src/app/(platform)/library/__tests__/main.test.tsx (NEW)
- Renders page with agent list (MSW 200)
- Shows loading state
- Shows error state (MSW 422)
- Handles empty agent list
2. src/app/(platform)/library/__tests__/search.test.tsx (NEW)
- Filters agents by search query
- Shows no results message
- Clears search
3. src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx (UPDATE)
- Add test for new "duplicate" action
```
Present this plan to the user. Wait for confirmation before proceeding. If the user has feedback, adjust the plan.
## Step 6: Write the tests
For each test file in the plan, follow these conventions:
### File structure
```tsx
import { render, screen, waitFor } from "@/tests/integrations/test-utils";
import { server } from "@/mocks/mock-server";
// Import MSW handlers for endpoints the page uses
import {
getGetV2ListLibraryAgentsMockHandler200,
getGetV2ListLibraryAgentsMockHandler422,
} from "@/app/api/__generated__/endpoints/library/library.msw";
// Import the component under test
import LibraryPage from "../page";
describe("LibraryPage", () => {
test("renders agent list from API", async () => {
server.use(getGetV2ListLibraryAgentsMockHandler200());
render(<LibraryPage />);
expect(await screen.findByText(/my agents/i)).toBeDefined();
});
test("shows error state on API failure", async () => {
server.use(getGetV2ListLibraryAgentsMockHandler422());
render(<LibraryPage />);
expect(await screen.findByText(/error/i)).toBeDefined();
});
});
```
### Rules
- Use `render()` from `@/tests/integrations/test-utils` (NOT from `@testing-library/react` directly)
- Use `server.use()` to set up MSW handlers BEFORE rendering
- Use `findBy*` (async) for elements that appear after data fetching — NOT `getBy*`
- Use `getBy*` only for elements that are immediately present in the DOM
- Use `screen` queries — do NOT destructure from `render()`
- Use `waitFor` when asserting side effects or state changes after interactions
- Import `fireEvent` or `userEvent` from the test-utils for interactions
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
- Keep tests focused: one behavior per test
- Use descriptive test names that read like sentences
### Test location
```
# For pages: __tests__/ next to page.tsx
src/app/(platform)/library/__tests__/main.test.tsx
# For complex standalone components: __tests__/ inside component folder
src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx
# For pure helpers: co-located .test.ts
src/app/(platform)/library/helpers.test.ts
```
### Custom MSW overrides
When the auto-generated faker data is not enough, override with specific data:
```tsx
import { http, HttpResponse } from "msw";
server.use(
http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
return HttpResponse.json({
agents: [
{ id: "1", name: "Test Agent", description: "A test agent" },
],
pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
});
}),
);
```
Use the proxy URL pattern: `http://localhost:3000/api/proxy/api/v{version}/{path}` — this matches the MSW base URL configured in `orval.config.ts`.
## Step 7: Run and verify
After writing all tests:
```bash
cd autogpt_platform/frontend
pnpm test:unit --reporter=verbose
```
If tests fail:
1. Read the error output carefully
2. Fix the test (not the source code, unless there is a genuine bug)
3. Re-run until all pass
Then run the full checks:
```bash
pnpm format
pnpm lint
pnpm types
```

View File

@@ -6,19 +6,11 @@ on:
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
concurrency:
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -27,22 +19,47 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic
working-directory: classic/original_autogpt
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
steps:
- name: Start MinIO service
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
@@ -54,23 +71,41 @@ jobs:
git config --global user.name "Auto-GPT-Bot"
git config --global user.email "github-bot@agpt.co"
- name: Set up Python 3.12
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: "3.12"
python-version: ${{ matrix.python-version }}
- id: get_date
name: Get date
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Python dependencies
run: poetry install
@@ -81,13 +116,12 @@ jobs:
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
--numprocesses=logical --durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
original_autogpt/tests/unit original_autogpt/tests/integration
tests/unit tests/integration
env:
CI: true
PLAIN_OUTPUT: True
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
@@ -101,11 +135,11 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: autogpt-agent
flags: autogpt-agent,${{ runner.os }}
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/logs/
path: classic/original_autogpt/logs/

View File

@@ -148,7 +148,7 @@ jobs:
--entrypoint poetry ${{ env.IMAGE_NAME }} run \
pytest -v --cov=autogpt --cov-branch --cov-report term-missing \
--numprocesses=4 --durations=10 \
original_autogpt/tests/unit original_autogpt/tests/integration 2>&1 | tee test_output.txt
tests/unit tests/integration 2>&1 | tee test_output.txt
test_failure=${PIPESTATUS[0]}

View File

@@ -10,9 +10,10 @@ on:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- '!**/*.md'
pull_request:
branches: [ master, dev, release-* ]
@@ -20,9 +21,10 @@ on:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- '!**/*.md'
defaults:
@@ -33,9 +35,13 @@ defaults:
jobs:
serve-agent-protocol:
runs-on: ubuntu-latest
strategy:
matrix:
agent-name: [ original_autogpt ]
fail-fast: false
timeout-minutes: 20
env:
min-python-version: '3.12'
min-python-version: '3.10'
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -49,22 +55,22 @@ jobs:
python-version: ${{ env.min-python-version }}
- name: Install Poetry
working-directory: ./classic/${{ matrix.agent-name }}/
run: |
curl -sSL https://install.python-poetry.org | python -
- name: Install dependencies
run: poetry install
- name: Run smoke tests with direct-benchmark
- name: Run regression tests
run: |
poetry run direct-benchmark run \
--strategies one_shot \
--models claude \
--tests ReadFile,WriteFile \
--json
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
poetry run agbenchmark --test=WriteFile
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
AGENT_NAME: ${{ matrix.agent-name }}
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
NONINTERACTIVE_MODE: "true"
CI: true
HELICONE_CACHE_ENABLED: false
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
TELEMETRY_ENVIRONMENT: autogpt-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}

View File

@@ -1,24 +1,18 @@
name: Classic - Direct Benchmark CI
name: Classic - AGBenchmark CI
on:
push:
branches: [ master, dev, ci-test* ]
paths:
- 'classic/direct_benchmark/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- .github/workflows/classic-benchmark-ci.yml
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- 'classic/direct_benchmark/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- .github/workflows/classic-benchmark-ci.yml
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
concurrency:
group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -29,16 +23,23 @@ defaults:
shell: bash
env:
min-python-version: '3.12'
min-python-version: '3.10'
jobs:
benchmark-tests:
runs-on: ubuntu-latest
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
defaults:
run:
shell: bash
working-directory: classic
working-directory: classic/benchmark
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -46,88 +47,71 @@ jobs:
fetch-depth: 0
submodules: true
- name: Set up Python ${{ env.min-python-version }}
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}
python-version: ${{ matrix.python-version }}
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
- name: Install Poetry
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
- name: Install dependencies
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Python dependencies
run: poetry install
- name: Run basic benchmark tests
- name: Run pytest with coverage
run: |
echo "Testing ReadFile challenge with one_shot strategy..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests ReadFile \
--json
echo "Testing WriteFile challenge..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests WriteFile \
--json
poetry run pytest -vv \
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Test category filtering
run: |
echo "Testing coding category..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--categories coding \
--tests ReadFile,WriteFile \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Test multiple strategies
run: |
echo "Testing multiple strategies..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot,plan_execute \
--models claude \
--tests ReadFile \
--parallel 2 \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: agbenchmark,${{ runner.os }}
# Run regression tests on maintain challenges
regression-tests:
self-test-with-agent:
runs-on: ubuntu-latest
timeout-minutes: 45
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
defaults:
run:
shell: bash
working-directory: classic
strategy:
matrix:
agent-name: [forge]
fail-fast: false
timeout-minutes: 20
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -140,31 +124,53 @@ jobs:
with:
python-version: ${{ env.min-python-version }}
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
- name: Install dependencies
run: poetry install
curl -sSL https://install.python-poetry.org | python -
- name: Run regression tests
working-directory: classic
run: |
echo "Running regression tests (previously beaten challenges)..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--maintain \
--parallel 4 \
--json
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
set +e # Ignore non-zero exit codes and continue execution
echo "Running the following command: poetry run agbenchmark --maintain --mock"
poetry run agbenchmark --maintain --mock
EXIT_CODE=$?
set -e # Stop ignoring non-zero exit codes
# Check if the exit code was 5, and if so, exit with 0 instead
if [ $EXIT_CODE -eq 5 ]; then
echo "regression_tests.json is empty."
fi
echo "Running the following command: poetry run agbenchmark --mock"
poetry run agbenchmark --mock
echo "Running the following command: poetry run agbenchmark --mock --category=data"
poetry run agbenchmark --mock --category=data
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
poetry run agbenchmark --mock --category=coding
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
# poetry run agbenchmark --test=WriteFile
cd ../benchmark
poetry install
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
export BUILD_SKILL_TREE=true
# poetry run agbenchmark --mock
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
# if [ ! -z "$CHANGED" ]; then
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
# echo "$CHANGED"
# exit 1
# else
# echo "No unstaged changes."
# fi
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}

View File

@@ -6,15 +6,13 @@ on:
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '!classic/forge/tests/vcr_cassettes'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '!classic/forge/tests/vcr_cassettes'
concurrency:
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -23,60 +21,131 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic
working-directory: classic/forge
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
steps:
- name: Start MinIO service
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Set up Python 3.12
- name: Checkout cassettes
if: ${{ startsWith(github.event_name, 'pull_request') }}
env:
PR_BASE: ${{ github.event.pull_request.base.ref }}
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
cassette_base_branch="${PR_BASE}"
cd tests/vcr_cassettes
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
cassette_base_branch="master"
fi
if git ls-remote --exit-code --heads origin $cassette_branch ; then
git fetch origin $cassette_branch
git fetch origin $cassette_base_branch
git checkout $cassette_branch
# Pick non-conflicting cassette updates from the base branch
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
echo "Using cassettes from mirror branch '$cassette_branch'," \
"synced to upstream branch '$cassette_base_branch'."
else
git checkout -b $cassette_branch
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
"Using cassettes from '$cassette_base_branch'."
fi
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: "3.12"
python-version: ${{ matrix.python-version }}
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Python dependencies
run: poetry install
- name: Install Playwright browsers
run: poetry run playwright install chromium
- name: Run pytest with coverage
run: |
poetry run pytest -vv \
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
forge/forge forge/tests
forge
env:
CI: true
PLAIN_OUTPUT: True
# API keys - tests that need these will skip if not available
# Secrets are not available to fork PRs (GitHub security feature)
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
@@ -90,11 +159,85 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: forge
flags: forge,${{ runner.os }}
- id: setup_git_auth
name: Set up git token authentication
# Cassettes may be pushed even when tests fail
if: success() || failure()
run: |
config_key="http.${{ github.server_url }}/.extraheader"
if [ "${{ runner.os }}" = 'macOS' ]; then
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
else
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
fi
git config "$config_key" \
"Authorization: Basic $base64_pat"
cd tests/vcr_cassettes
git config "$config_key" \
"Authorization: Basic $base64_pat"
echo "config_key=$config_key" >> $GITHUB_OUTPUT
- id: push_cassettes
name: Push updated cassettes
# For pull requests, push updated cassettes even when tests fail
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
env:
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
is_pull_request=true
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
else
cassette_branch="${{ github.ref_name }}"
fi
cd tests/vcr_cassettes
# Commit & push changes to cassettes if any
if ! git diff --quiet; then
git add .
git commit -m "Auto-update cassettes"
git push origin HEAD:$cassette_branch
if [ ! $is_pull_request ]; then
cd ../..
git add tests/vcr_cassettes
git commit -m "Update cassette submodule"
git push origin HEAD:$cassette_branch
fi
echo "updated=true" >> $GITHUB_OUTPUT
else
echo "updated=false" >> $GITHUB_OUTPUT
echo "No cassette changes to commit"
fi
- name: Post Set up git token auth
if: steps.setup_git_auth.outcome == 'success'
run: |
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
- name: Apply "behaviour change" label and comment on PR
if: ${{ startsWith(github.event_name, 'pull_request') }}
run: |
PR_NUMBER="${{ github.event.pull_request.number }}"
TOKEN="${{ secrets.PAT_REVIEW }}"
REPO="${{ github.repository }}"
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
echo "Adding label and comment..."
echo $TOKEN | gh auth login --with-token
gh issue edit $PR_NUMBER --add-label "behaviour change"
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
fi
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/logs/
path: classic/forge/logs/

View File

@@ -0,0 +1,60 @@
name: Classic - Frontend CI/CD
on:
push:
branches:
- master
- dev
- 'ci-test*' # This will match any branch that starts with "ci-test"
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'
pull_request:
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'
jobs:
build:
permissions:
contents: write
pull-requests: write
runs-on: ubuntu-latest
env:
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
steps:
- name: Checkout Repo
uses: actions/checkout@v4
- name: Setup Flutter
uses: subosito/flutter-action@v2
with:
flutter-version: '3.13.2'
- name: Build Flutter to Web
run: |
cd classic/frontend
flutter build web --base-href /app/
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
# if: github.event_name == 'push'
# run: |
# git config --local user.email "action@github.com"
# git config --local user.name "GitHub Action"
# git add classic/frontend/build/web
# git checkout -B ${{ env.BUILD_BRANCH }}
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
# git push -f origin ${{ env.BUILD_BRANCH }}
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
if: github.event_name == 'push'
uses: peter-evans/create-pull-request@v8
with:
add-paths: classic/frontend/build/web
base: ${{ github.ref_name }}
branch: ${{ env.BUILD_BRANCH }}
delete-branch: true
title: "Update frontend build in `${{ github.ref_name }}`"
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
commit-message: "Update frontend build based on commit ${{ github.sha }}"

View File

@@ -7,9 +7,7 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- 'classic/benchmark/**'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
pull_request:
@@ -18,9 +16,7 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- 'classic/benchmark/**'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
@@ -31,13 +27,44 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic
jobs:
get-changed-parts:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- id: changes-in
name: Determine affected subprojects
uses: dorny/paths-filter@v3
with:
filters: |
original_autogpt:
- classic/original_autogpt/autogpt/**
- classic/original_autogpt/tests/**
- classic/original_autogpt/poetry.lock
forge:
- classic/forge/forge/**
- classic/forge/tests/**
- classic/forge/poetry.lock
benchmark:
- classic/benchmark/agbenchmark/**
- classic/benchmark/tests/**
- classic/benchmark/poetry.lock
outputs:
changed-parts: ${{ steps.changes-in.outputs.changes }}
lint:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.12"
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
steps:
- name: Checkout repository
@@ -54,31 +81,42 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry install
run: poetry -C classic/${{ matrix.sub-package }} install
# Lint
- name: Lint (isort)
run: poetry run isort --check .
working-directory: classic/${{ matrix.sub-package }}
- name: Lint (Black)
if: success() || failure()
run: poetry run black --check .
working-directory: classic/${{ matrix.sub-package }}
- name: Lint (Flake8)
if: success() || failure()
run: poetry run flake8 .
working-directory: classic/${{ matrix.sub-package }}
types:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.12"
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
steps:
- name: Checkout repository
@@ -95,16 +133,19 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry install
run: poetry -C classic/${{ matrix.sub-package }} install
# Typecheck
- name: Typecheck
if: success() || failure()
run: poetry run pyright
working-directory: classic/${{ matrix.sub-package }}

View File

@@ -269,14 +269,12 @@ jobs:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
- name: Run pytest with coverage
- name: Run pytest
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG \
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
else
poetry run pytest -s -vv \
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
poetry run pytest -s -vv
fi
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
@@ -289,13 +287,11 @@ jobs:
REDIS_PORT: "6379"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-backend
files: ./autogpt_platform/backend/coverage.xml
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# flags: backend,${{ runner.os }}
env:
CI: true

View File

@@ -148,11 +148,3 @@ jobs:
- name: Run Integration Tests
run: pnpm test:unit
- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-frontend
files: ./autogpt_platform/frontend/coverage/cobertura-coverage.xml

View File

@@ -179,30 +179,21 @@ jobs:
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
export NEXT_PUBLIC_SOURCEMAPS NEXT_PUBLIC_PW_TEST
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Ensure NEXT_PUBLIC_SOURCEMAPS is in resolved compose
# (docker compose config on some versions drops this arg)
if ! grep -q "NEXT_PUBLIC_SOURCEMAPS" docker-compose.resolved.yml; then
echo "Injecting NEXT_PUBLIC_SOURCEMAPS into resolved compose (docker compose config dropped it)"
sed -i '/NEXT_PUBLIC_PW_TEST/a\ NEXT_PUBLIC_SOURCEMAPS: "true"' docker-compose.resolved.yml
fi
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}-sourcemaps" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
NEXT_PUBLIC_SOURCEMAPS: true
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
@@ -288,11 +279,6 @@ jobs:
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Copy source maps from Docker for E2E coverage
run: |
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
docker cp "$FRONTEND_CONTAINER":/app/.next/static .next-static-coverage
- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile
@@ -303,15 +289,6 @@ jobs:
run: pnpm test:no-build
continue-on-error: false
- name: Upload E2E coverage to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-frontend-e2e
files: ./autogpt_platform/frontend/coverage/e2e/cobertura-coverage.xml
disable_search: true
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4

10
.gitignore vendored
View File

@@ -3,7 +3,6 @@
classic/original_autogpt/keys.py
classic/original_autogpt/*.json
auto_gpt_workspace/*
.autogpt/
*.mpeg
.env
# Root .env files
@@ -17,7 +16,6 @@ log-ingestion.txt
/logs
*.log
*.mp3
!autogpt_platform/frontend/public/notification.mp3
mem.sqlite3
venvAutoGPT
@@ -161,10 +159,6 @@ CURRENT_BULLETIN.md
# AgBenchmark
classic/benchmark/agbenchmark/reports/
classic/reports/
classic/direct_benchmark/reports/
classic/.benchmark_workspaces/
classic/direct_benchmark/.benchmark_workspaces/
# Nodejs
package-lock.json
@@ -183,13 +177,9 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
**/.claude/settings.local.json
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
# Test database
test.db
.next
# Implementation plans (generated by AI agents)
plans/

View File

@@ -1,36 +0,0 @@
title = "AutoGPT Gitleaks Config"
[extend]
useDefault = true
[allowlist]
description = "Global allowlist"
paths = [
# Template/example env files (no real secrets)
'''\.env\.(default|example|template)$''',
# Lock files
'''pnpm-lock\.yaml$''',
'''poetry\.lock$''',
# Secrets baseline
'''\.secrets\.baseline$''',
# Build artifacts and caches (should not be committed)
'''__pycache__/''',
'''classic/frontend/build/''',
# Docker dev setup (local dev JWTs/keys only)
'''autogpt_platform/db/docker/''',
# Load test configs (dev JWTs)
'''load-tests/configs/''',
# Test files with fake/fixture keys (_test.py, test_*.py, conftest.py)
'''(_test|test_.*|conftest)\.py$''',
# Documentation (only contains placeholder keys in curl/API examples)
'''docs/.*\.md$''',
# Firebase config (public API keys by design)
'''google-services\.json$''',
'''classic/frontend/(lib|web)/''',
]
# CI test-only encryption key (marked DO NOT USE IN PRODUCTION)
regexes = [
'''dvziYgz0KSK8FENhju0ZYi8''',
# LLM model name enum values falsely flagged as API keys
'''Llama-\d.*Instruct''',
]

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "classic/forge/tests/vcr_cassettes"]
path = classic/forge/tests/vcr_cassettes
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes

View File

@@ -23,15 +23,9 @@ repos:
- id: detect-secrets
name: Detect secrets
description: Detects high entropy strings that are likely to be passwords.
args: ["--baseline", ".secrets.baseline"]
files: ^autogpt_platform/
exclude: (pnpm-lock\.yaml|\.env\.(default|example|template))$
- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.3
hooks:
- id: gitleaks
name: Detect secrets (gitleaks)
exclude: pnpm-lock\.yaml$
stages: [pre-push]
- repo: local
# For proper type checking, all dependencies need to be up-to-date.
@@ -90,16 +84,51 @@ repos:
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic
alias: poetry-install-classic
name: Check & Install dependencies - Classic - AutoGPT
alias: poetry-install-classic-autogpt
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/poetry\.lock$" || exit 0;
poetry -C classic install
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
poetry -C classic/original_autogpt install
'
# include forge source (since it's a path dependency)
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Forge
alias: poetry-install-classic-forge
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
poetry -C classic/forge install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Benchmark
alias: poetry-install-classic-benchmark
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
poetry -C classic/benchmark install
'
always_run: true
language: system
@@ -194,10 +223,26 @@ repos:
language: system
- id: isort
name: Lint (isort) - Classic
alias: isort-classic
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
files: ^classic/(original_autogpt|forge|direct_benchmark)/
name: Lint (isort) - Classic - AutoGPT
alias: isort-classic-autogpt
entry: poetry -P classic/original_autogpt run isort -p autogpt
files: ^classic/original_autogpt/
types: [file, python]
language: system
- id: isort
name: Lint (isort) - Classic - Forge
alias: isort-classic-forge
entry: poetry -P classic/forge run isort -p forge
files: ^classic/forge/
types: [file, python]
language: system
- id: isort
name: Lint (isort) - Classic - Benchmark
alias: isort-classic-benchmark
entry: poetry -P classic/benchmark run isort -p agbenchmark
files: ^classic/benchmark/
types: [file, python]
language: system
@@ -211,13 +256,26 @@ repos:
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
# Use consolidated flake8 config at classic/.flake8
# To have flake8 load the config of the individual subprojects, we have to call
# them separately.
hooks:
- id: flake8
name: Lint (Flake8) - Classic
alias: flake8-classic
files: ^classic/(original_autogpt|forge|direct_benchmark)/
args: [--config=classic/.flake8]
name: Lint (Flake8) - Classic - AutoGPT
alias: flake8-classic-autogpt
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
args: [--config=classic/original_autogpt/.flake8]
- id: flake8
name: Lint (Flake8) - Classic - Forge
alias: flake8-classic-forge
files: ^classic/forge/(forge|tests)/
args: [--config=classic/forge/.flake8]
- id: flake8
name: Lint (Flake8) - Classic - Benchmark
alias: flake8-classic-benchmark
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
args: [--config=classic/benchmark/.flake8]
- repo: local
hooks:
@@ -253,10 +311,29 @@ repos:
pass_filenames: false
- id: pyright
name: Typecheck - Classic
alias: pyright-classic
entry: poetry -C classic run pyright
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
name: Typecheck - Classic - AutoGPT
alias: pyright-classic-autogpt
entry: poetry -C classic/original_autogpt run pyright
# include forge source (since it's a path dependency) but exclude *_test.py files:
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Classic - Forge
alias: pyright-classic-forge
entry: poetry -C classic/forge run pyright
files: ^classic/forge/(forge/|poetry\.lock$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Classic - Benchmark
alias: pyright-classic-benchmark
entry: poetry -C classic/benchmark run pyright
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
types: [file]
language: system
pass_filenames: false
@@ -283,9 +360,26 @@ repos:
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic (excl. slow tests)
# alias: pytest-classic
# entry: bash -c 'cd classic && poetry run pytest -m "not slow"'
# files: ^classic/(original_autogpt|forge|direct_benchmark)/
# name: Run tests - Classic - AutoGPT (excl. slow tests)
# alias: pytest-classic-autogpt
# entry: bash -c 'cd classic/original_autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
# # include forge source (since it's a path dependency) but exclude *_test.py files:
# files: ^(classic/original_autogpt/((autogpt|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
# language: system
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic - Forge (excl. slow tests)
# alias: pytest-classic-forge
# entry: bash -c 'cd classic/forge && poetry run pytest --cov=forge -m "not slow"'
# files: ^classic/forge/(forge/|tests/|poetry\.lock$)
# language: system
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic - Benchmark
# alias: pytest-classic-benchmark
# entry: bash -c 'cd classic/benchmark && poetry run pytest --cov=benchmark'
# files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
# language: system
# pass_filenames: false

View File

@@ -1,467 +0,0 @@
{
"version": "1.5.0",
"plugins_used": [
{
"name": "ArtifactoryDetector"
},
{
"name": "AWSKeyDetector"
},
{
"name": "AzureStorageKeyDetector"
},
{
"name": "Base64HighEntropyString",
"limit": 4.5
},
{
"name": "BasicAuthDetector"
},
{
"name": "CloudantDetector"
},
{
"name": "DiscordBotTokenDetector"
},
{
"name": "GitHubTokenDetector"
},
{
"name": "GitLabTokenDetector"
},
{
"name": "HexHighEntropyString",
"limit": 3.0
},
{
"name": "IbmCloudIamDetector"
},
{
"name": "IbmCosHmacDetector"
},
{
"name": "IPPublicDetector"
},
{
"name": "JwtTokenDetector"
},
{
"name": "KeywordDetector",
"keyword_exclude": ""
},
{
"name": "MailchimpDetector"
},
{
"name": "NpmDetector"
},
{
"name": "OpenAIDetector"
},
{
"name": "PrivateKeyDetector"
},
{
"name": "PypiTokenDetector"
},
{
"name": "SendGridDetector"
},
{
"name": "SlackDetector"
},
{
"name": "SoftlayerDetector"
},
{
"name": "SquareOAuthDetector"
},
{
"name": "StripeDetector"
},
{
"name": "TelegramBotTokenDetector"
},
{
"name": "TwilioKeyDetector"
}
],
"filters_used": [
{
"path": "detect_secrets.filters.allowlist.is_line_allowlisted"
},
{
"path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
"min_level": 2
},
{
"path": "detect_secrets.filters.heuristic.is_indirect_reference"
},
{
"path": "detect_secrets.filters.heuristic.is_likely_id_string"
},
{
"path": "detect_secrets.filters.heuristic.is_lock_file"
},
{
"path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string"
},
{
"path": "detect_secrets.filters.heuristic.is_potential_uuid"
},
{
"path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign"
},
{
"path": "detect_secrets.filters.heuristic.is_sequential_string"
},
{
"path": "detect_secrets.filters.heuristic.is_swagger_file"
},
{
"path": "detect_secrets.filters.heuristic.is_templated_secret"
},
{
"path": "detect_secrets.filters.regex.should_exclude_file",
"pattern": [
"\\.env$",
"pnpm-lock\\.yaml$",
"\\.env\\.(default|example|template)$",
"__pycache__",
"_test\\.py$",
"test_.*\\.py$",
"conftest\\.py$",
"poetry\\.lock$",
"node_modules"
]
}
],
"results": {
"autogpt_platform/backend/backend/api/external/v1/integrations.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/api/external/v1/integrations.py",
"hashed_secret": "665b1e3851eefefa3fb878654292f16597d25155",
"is_verified": false,
"line_number": 289
}
],
"autogpt_platform/backend/backend/blocks/airtable/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/airtable/_config.py",
"hashed_secret": "57e168b03afb7c1ee3cdc4ee3db2fe1cc6e0df26",
"is_verified": false,
"line_number": 29
}
],
"autogpt_platform/backend/backend/blocks/dataforseo/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/dataforseo/_config.py",
"hashed_secret": "32ce93887331fa5d192f2876ea15ec000c7d58b8",
"is_verified": false,
"line_number": 12
}
],
"autogpt_platform/backend/backend/blocks/github/checks.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/checks.py",
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
"is_verified": false,
"line_number": 108
}
],
"autogpt_platform/backend/backend/blocks/github/ci.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/ci.py",
"hashed_secret": "90bd1b48e958257948487b90bee080ba5ed00caa",
"is_verified": false,
"line_number": 123
}
],
"autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "f96896dafced7387dcd22343b8ea29d3d2c65663",
"is_verified": false,
"line_number": 42
},
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "b80a94d5e70bedf4f5f89d2f5a5255cc9492d12e",
"is_verified": false,
"line_number": 193
},
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "75b17e517fe1b3136394f6bec80c4f892da75e42",
"is_verified": false,
"line_number": 344
},
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "b0bfb5e4e2394e7f8906e5ed1dffd88b2bc89dd5",
"is_verified": false,
"line_number": 534
}
],
"autogpt_platform/backend/backend/blocks/github/statuses.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/statuses.py",
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
"is_verified": false,
"line_number": 85
}
],
"autogpt_platform/backend/backend/blocks/google/docs.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/google/docs.py",
"hashed_secret": "c95da0c6696342c867ef0c8258d2f74d20fd94d4",
"is_verified": false,
"line_number": 203
}
],
"autogpt_platform/backend/backend/blocks/google/sheets.py": [
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/google/sheets.py",
"hashed_secret": "bd5a04fa3667e693edc13239b6d310c5c7a8564b",
"is_verified": false,
"line_number": 57
}
],
"autogpt_platform/backend/backend/blocks/linear/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/linear/_config.py",
"hashed_secret": "b37f020f42d6d613b6ce30103e4d408c4499b3bb",
"is_verified": false,
"line_number": 53
}
],
"autogpt_platform/backend/backend/blocks/medium.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/medium.py",
"hashed_secret": "ff998abc1ce6d8f01a675fa197368e44c8916e9c",
"is_verified": false,
"line_number": 131
}
],
"autogpt_platform/backend/backend/blocks/replicate/replicate_block.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
"hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
"is_verified": false,
"line_number": 55
}
],
"autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/slant3d/webhook.py",
"hashed_secret": "36263c76947443b2f6e6b78153967ac4a7da99f9",
"is_verified": false,
"line_number": 100
}
],
"autogpt_platform/backend/backend/blocks/talking_head.py": [
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/talking_head.py",
"hashed_secret": "44ce2d66222529eea4a32932823466fc0601c799",
"is_verified": false,
"line_number": 113
}
],
"autogpt_platform/backend/backend/blocks/wordpress/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/wordpress/_config.py",
"hashed_secret": "e62679512436161b78e8a8d68c8829c2a1031ccb",
"is_verified": false,
"line_number": 17
}
],
"autogpt_platform/backend/backend/util/cache.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/util/cache.py",
"hashed_secret": "37f0c918c3fa47ca4a70e42037f9f123fdfbc75b",
"is_verified": false,
"line_number": 449
}
],
"autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts",
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
"is_verified": false,
"line_number": 6
}
],
"autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json",
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
"is_verified": false,
"line_number": 5
}
],
"autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json",
"hashed_secret": "5a6d1c612954979ea99ee33dbb2d231b00f6ac0a",
"is_verified": false,
"line_number": 5
}
],
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
"is_verified": false,
"line_number": 6
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
"is_verified": false,
"line_number": 8
}
],
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
"is_verified": false,
"line_number": 5
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
"is_verified": false,
"line_number": 7
}
],
"autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
"is_verified": false,
"line_number": 192
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
"hashed_secret": "86275db852204937bbdbdebe5fabe8536e030ab6",
"is_verified": false,
"line_number": 193
}
],
"autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
"hashed_secret": "47acd2028cf81b5da88ddeedb2aea4eca4b71fbd",
"is_verified": false,
"line_number": 102
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
"is_verified": false,
"line_number": 103
}
],
"autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts": [
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "9c486c92f1a7420e1045c7ad963fbb7ba3621025",
"is_verified": false,
"line_number": 73
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "9277508c7a6effc8fb59163efbfada189e35425c",
"is_verified": false,
"line_number": 75
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "8dc7e2cb1d0935897d541bf5facab389b8a50340",
"is_verified": false,
"line_number": 77
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "79a26ad48775944299be6aaf9fb1d5302c1ed75b",
"is_verified": false,
"line_number": 79
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "a3b62b44500a1612e48d4cab8294df81561b3b1a",
"is_verified": false,
"line_number": 81
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "a58979bd0b21ef4f50417d001008e60dd7a85c64",
"is_verified": false,
"line_number": 83
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "6cb6e075f8e8c7c850f9d128d6608e5dbe209a79",
"is_verified": false,
"line_number": 85
}
],
"autogpt_platform/frontend/src/lib/constants.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/lib/constants.ts",
"hashed_secret": "27b924db06a28cc755fb07c54f0fddc30659fe4d",
"is_verified": false,
"line_number": 10
}
],
"autogpt_platform/frontend/src/tests/credentials/index.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/tests/credentials/index.ts",
"hashed_secret": "c18006fc138809314751cd1991f1e0b820fabd37",
"is_verified": false,
"line_number": 4
}
]
},
"generated_at": "2026-04-02T13:10:54Z"
}

View File

@@ -1,6 +1,6 @@
# AutoGPT Platform Contribution Guide
This guide provides context for coding agents when updating the **autogpt_platform** folder.
This guide provides context for Codex when updating the **autogpt_platform** folder.
## Directory overview
@@ -30,7 +30,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
- Regenerate with `pnpm generate:api`
- Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Integration tests (Vitest + RTL + MSW) are the default (~90%, page-level). Playwright for E2E critical flows. Storybook for design system components. See `autogpt_platform/frontend/TESTING.md`
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
@@ -47,9 +47,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
## Testing
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
- Frontend integration tests: `pnpm test:unit` (Vitest + RTL + MSW, primary testing approach).
- Frontend E2E tests: `pnpm test` or `pnpm test-ui` for Playwright tests.
- See `autogpt_platform/frontend/TESTING.md` for the full testing strategy.
- Frontend: `pnpm test` or `pnpm test-ui` for Playwright tests. See `docs/content/platform/contributing/tests.md` for tips.
Always run the relevant linters and tests before committing.
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).

View File

@@ -1 +0,0 @@
@AGENTS.md

View File

@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.

View File

@@ -1,120 +0,0 @@
# AutoGPT Platform
This file provides guidance to coding agents when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Component Documentation
- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns
## Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the github pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.

View File

@@ -1 +1,120 @@
@AGENTS.md
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Component Documentation
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
## Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the github pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.

View File

@@ -178,7 +178,6 @@ SMTP_USERNAME=
SMTP_PASSWORD=
# Business & Marketing Tools
AGENTMAIL_API_KEY=
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=

View File

@@ -1,227 +0,0 @@
# Backend
This file provides guidance to coding agents when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
# 3. Implement the fix
# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications

View File

@@ -1 +1,227 @@
@AGENTS.md
# CLAUDE.md - Backend
This file provides guidance to Claude Code when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
# 3. Implement the fix
# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications

View File

@@ -31,10 +31,7 @@ from backend.data.model import (
UserPasswordCredentials,
is_sdk_default,
)
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
@@ -621,11 +618,6 @@ async def delete_credential(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
if not creds:
raise HTTPException(

View File

@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):
def _create_ephemeral_session(user_id: str) -> ChatSession:
"""Create an ephemeral session for stateless API requests."""
return ChatSession.new(user_id, dry_run=False)
return ChatSession.new(user_id)
@tools_router.post(

View File

@@ -9,14 +9,11 @@ from pydantic import BaseModel
from backend.copilot.config import ChatConfig
from backend.copilot.rate_limit import (
SubscriptionTier,
get_global_rate_limits,
get_usage_status,
get_user_tier,
reset_user_usage,
set_user_tier,
)
from backend.data.user import get_user_by_email, get_user_email_by_id, search_users
from backend.data.user import get_user_by_email, get_user_email_by_id
logger = logging.getLogger(__name__)
@@ -36,17 +33,6 @@ class UserRateLimitResponse(BaseModel):
weekly_token_limit: int
daily_tokens_used: int
weekly_tokens_used: int
tier: SubscriptionTier
class UserTierResponse(BaseModel):
user_id: str
tier: SubscriptionTier
class SetUserTierRequest(BaseModel):
user_id: str
tier: SubscriptionTier
async def _resolve_user_id(
@@ -100,10 +86,10 @@ async def get_user_rate_limit(
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
daily_limit, weekly_limit, tier = await get_global_rate_limits(
daily_limit, weekly_limit = await get_global_rate_limits(
resolved_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
return UserRateLimitResponse(
user_id=resolved_id,
@@ -112,7 +98,6 @@ async def get_user_rate_limit(
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
tier=tier,
)
@@ -140,10 +125,10 @@ async def reset_user_rate_limit(
logger.exception("Failed to reset user usage")
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
daily_limit, weekly_limit, tier = await get_global_rate_limits(
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
try:
resolved_email = await get_user_email_by_id(user_id)
@@ -158,102 +143,4 @@ async def reset_user_rate_limit(
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
tier=tier,
)
@router.get(
"/rate_limit/tier",
response_model=UserTierResponse,
summary="Get User Rate Limit Tier",
)
async def get_user_rate_limit_tier(
user_id: str,
admin_user_id: str = Security(get_user_id),
) -> UserTierResponse:
"""Get a user's current rate-limit tier. Admin-only.
Returns 404 if the user does not exist in the database.
"""
logger.info("Admin %s checking tier for user %s", admin_user_id, user_id)
resolved_email = await get_user_email_by_id(user_id)
if resolved_email is None:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
tier = await get_user_tier(user_id)
return UserTierResponse(user_id=user_id, tier=tier)
@router.post(
"/rate_limit/tier",
response_model=UserTierResponse,
summary="Set User Rate Limit Tier",
)
async def set_user_rate_limit_tier(
request: SetUserTierRequest,
admin_user_id: str = Security(get_user_id),
) -> UserTierResponse:
"""Set a user's rate-limit tier. Admin-only.
Returns 404 if the user does not exist in the database.
"""
try:
resolved_email = await get_user_email_by_id(request.user_id)
except Exception:
logger.warning(
"Failed to resolve email for user %s",
request.user_id,
exc_info=True,
)
resolved_email = None
if resolved_email is None:
raise HTTPException(status_code=404, detail=f"User {request.user_id} not found")
old_tier = await get_user_tier(request.user_id)
logger.info(
"Admin %s changing tier for user %s (%s): %s -> %s",
admin_user_id,
request.user_id,
resolved_email,
old_tier.value,
request.tier.value,
)
try:
await set_user_tier(request.user_id, request.tier)
except Exception as e:
logger.exception("Failed to set user tier")
raise HTTPException(status_code=500, detail="Failed to set tier") from e
return UserTierResponse(user_id=request.user_id, tier=request.tier)
class UserSearchResult(BaseModel):
user_id: str
user_email: Optional[str] = None
@router.get(
"/rate_limit/search_users",
response_model=list[UserSearchResult],
summary="Search Users by Name or Email",
)
async def admin_search_users(
query: str,
limit: int = 20,
admin_user_id: str = Security(get_user_id),
) -> list[UserSearchResult]:
"""Search users by partial email or name. Admin-only.
Queries the User table directly — returns results even for users
without credit transaction history.
"""
if len(query.strip()) < 3:
raise HTTPException(
status_code=400,
detail="Search query must be at least 3 characters.",
)
logger.info("Admin %s searching users with query=%r", admin_user_id, query)
results = await search_users(query, limit=max(1, min(limit, 50)))
return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]

View File

@@ -9,7 +9,7 @@ import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from .rate_limit_admin_routes import router as rate_limit_admin_router
@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
return_value=(2_500_000, 12_500_000),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
@@ -89,7 +89,6 @@ def test_get_rate_limit(
assert data["weekly_token_limit"] == 12_500_000
assert data["daily_tokens_used"] == 500_000
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
@@ -163,7 +162,6 @@ def test_reset_user_usage_daily_only(
assert data["daily_tokens_used"] == 0
# Weekly is untouched
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
@@ -194,7 +192,6 @@ def test_reset_user_usage_daily_and_weekly(
data = response.json()
assert data["daily_tokens_used"] == 0
assert data["weekly_tokens_used"] == 0
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
@@ -231,7 +228,7 @@ def test_get_rate_limit_email_lookup_failure(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
return_value=(2_500_000, 12_500_000),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
@@ -264,303 +261,3 @@ def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
json={"user_id": "test"},
)
assert response.status_code == 403
# ---------------------------------------------------------------------------
# Tier management endpoints
# ---------------------------------------------------------------------------
def test_get_user_tier(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test getting a user's rate-limit tier."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
)
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "PRO"
def test_get_user_tier_user_not_found(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that getting tier for a non-existent user returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=None,
)
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
assert response.status_code == 404
def test_set_user_tier(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test setting a user's rate-limit tier (upgrade)."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
)
mock_set = mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "ENTERPRISE"},
)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "ENTERPRISE"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE)
def test_set_user_tier_downgrade(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test downgrading a user's tier from PRO to FREE."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
)
mock_set = mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "FREE"},
)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "FREE"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
def test_set_user_tier_invalid_tier(
target_user_id: str,
) -> None:
"""Test that setting an invalid tier returns 422."""
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "invalid"},
)
assert response.status_code == 422
def test_set_user_tier_invalid_tier_uppercase(
target_user_id: str,
) -> None:
"""Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422.
Regression: ensures Pydantic enum validation rejects values that are not
members of SubscriptionTier, even when they look like valid enum names.
"""
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "INVALID"},
)
assert response.status_code == 422
body = response.json()
assert "detail" in body
def test_set_user_tier_email_lookup_failure_returns_404(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that email lookup failure returns 404 (user unverifiable)."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
side_effect=Exception("DB connection failed"),
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 404
def test_set_user_tier_user_not_found(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that setting tier for a non-existent user returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=None,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 404
def test_set_user_tier_db_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that DB failure on set tier returns 500."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
)
mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
side_effect=Exception("DB connection refused"),
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 500
def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None:
"""Test that tier admin endpoints require admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get("/admin/rate_limit/tier", params={"user_id": "test"})
assert response.status_code == 403
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": "test", "tier": "PRO"},
)
assert response.status_code == 403
# ─── search_users endpoint ──────────────────────────────────────────
def test_search_users_returns_matching_users(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Partial search should return all matching users from the User table."""
mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[
("user-1", "zamil.majdy@gmail.com"),
("user-2", "zamil.majdy@agpt.co"),
],
)
response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"})
assert response.status_code == 200
results = response.json()
assert len(results) == 2
assert results[0]["user_email"] == "zamil.majdy@gmail.com"
assert results[1]["user_email"] == "zamil.majdy@agpt.co"
def test_search_users_empty_results(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Search with no matches returns empty list."""
mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[],
)
response = client.get(
"/admin/rate_limit/search_users", params={"query": "nonexistent"}
)
assert response.status_code == 200
assert response.json() == []
def test_search_users_short_query_rejected(
admin_user_id: str,
) -> None:
"""Query shorter than 3 characters should return 400."""
response = client.get("/admin/rate_limit/search_users", params={"query": "ab"})
assert response.status_code == 400
def test_search_users_negative_limit_clamped(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Negative limit should be clamped to 1, not passed through."""
mock_search = mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[],
)
response = client.get(
"/admin/rate_limit/search_users", params={"query": "test", "limit": -1}
)
assert response.status_code == 200
mock_search.assert_awaited_once_with("test", limit=1)
def test_search_users_requires_admin_role(mock_jwt_user) -> None:
"""Test that the search_users endpoint requires admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get("/admin/rate_limit/search_users", params={"query": "test"})
assert response.status_code == 403

View File

@@ -11,17 +11,15 @@ from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
ChatSession,
ChatSessionMetadata,
append_and_save_message,
create_chat_session,
delete_chat_session,
@@ -112,23 +110,6 @@ class StreamChatRequest(BaseModel):
file_ids: list[str] | None = Field(
default=None, max_length=20
) # Workspace file IDs attached to this message
mode: CopilotMode | None = Field(
default=None,
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
"If None, uses the server default (extended_thinking).",
)
class CreateSessionRequest(BaseModel):
"""Request model for creating a new chat session.
``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
Extra/unknown fields are rejected (422) to prevent silent mis-use.
"""
model_config = ConfigDict(extra="forbid")
dry_run: bool = False
class CreateSessionResponse(BaseModel):
@@ -137,7 +118,6 @@ class CreateSessionResponse(BaseModel):
id: str
created_at: str
user_id: str | None
metadata: ChatSessionMetadata = ChatSessionMetadata()
class ActiveStreamInfo(BaseModel):
@@ -156,11 +136,8 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
has_more_messages: bool = False
oldest_sequence: int | None = None
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
metadata: ChatSessionMetadata = ChatSessionMetadata()
class SessionSummaryResponse(BaseModel):
@@ -271,7 +248,6 @@ async def list_sessions(
)
async def create_session(
user_id: Annotated[str, Security(auth.get_user_id)],
request: CreateSessionRequest | None = None,
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -280,28 +256,22 @@ async def create_session(
Args:
user_id: The authenticated user ID parsed from the JWT (required).
request: Optional request body. When provided, ``dry_run=True``
forces run_block and run_agent calls to use dry-run simulation.
Returns:
CreateSessionResponse: Details of the created session.
"""
dry_run = request.dry_run if request else False
logger.info(
f"Creating session with user_id: "
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
f"{', dry_run=True' if dry_run else ''}"
)
session = await create_chat_session(user_id, dry_run=dry_run)
session = await create_chat_session(user_id)
return CreateSessionResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
user_id=session.user_id,
metadata=session.metadata,
)
@@ -397,78 +367,59 @@ async def update_session_title_route(
async def get_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
limit: int = Query(default=50, ge=1, le=200),
before_sequence: int | None = Query(default=None, ge=0),
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
When no pagination params are provided, returns the most recent messages.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns active_stream info for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The authenticated user's ID.
limit: Maximum number of messages to return (1-200, default 50).
before_sequence: Return messages with sequence < this value (cursor).
user_id: The optional authenticated user ID, or None for anonymous access.
Returns:
SessionDetailResponse: Details for the requested session, including
active_stream info and pagination metadata.
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
"""
page = await get_chat_messages_paginated(
session_id, limit, before_sequence, user_id=user_id
)
if page is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in page.messages]
# Only check active stream on initial load (not on "load more" requests)
messages = [message.model_dump() for message in session.messages]
# Check if there's an active stream for this session
active_stream_info = None
if before_sequence is None:
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
# Skip session metadata on "load more" — frontend only needs messages
if before_sequence is not None:
return SessionDetailResponse(
id=page.session.session_id,
created_at=page.session.started_at.isoformat(),
updated_at=page.session.updated_at.isoformat(),
user_id=page.session.user_id or None,
messages=messages,
active_stream=None,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
total_prompt_tokens=0,
total_completion_tokens=0,
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
# Keep the assistant message (including tool_calls) so the frontend can
# render the correct tool UI (e.g. CreateAgent with mini game).
# convertChatSessionToUiMessages handles isComplete=false by setting
# tool parts without output to state "input-available".
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
total_prompt = sum(u.prompt_tokens for u in page.session.usage)
total_completion = sum(u.completion_tokens for u in page.session.usage)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
return SessionDetailResponse(
id=page.session.session_id,
created_at=page.session.started_at.isoformat(),
updated_at=page.session.updated_at.isoformat(),
user_id=page.session.user_id or None,
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
metadata=page.session.metadata,
)
@@ -482,9 +433,8 @@ async def get_copilot_usage(
Returns current token usage vs limits for daily and weekly windows.
Global defaults sourced from LaunchDarkly (falling back to config).
Includes the user's rate-limit tier.
"""
daily_limit, weekly_limit, tier = await get_global_rate_limits(
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
return await get_usage_status(
@@ -492,7 +442,6 @@ async def get_copilot_usage(
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
@@ -544,7 +493,7 @@ async def reset_copilot_usage(
detail="Rate limit reset is not available (credit system is disabled).",
)
daily_limit, weekly_limit, tier = await get_global_rate_limits(
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
@@ -578,13 +527,10 @@ async def reset_copilot_usage(
try:
# Verify the user is actually at or over their daily limit.
# (rate_limit_reset_cost intentionally omitted — this object is only
# used for limit checks, not returned to the client.)
usage_status = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
tier=tier,
)
if daily_limit > 0 and usage_status.daily.used < daily_limit:
raise HTTPException(
@@ -660,7 +606,6 @@ async def reset_copilot_usage(
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
return RateLimitResetResponse(
@@ -771,7 +716,7 @@ async def stream_chat_post(
# Global defaults sourced from LaunchDarkly, falling back to config.
if user_id:
try:
daily_limit, weekly_limit, _ = await get_global_rate_limits(
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
await check_rate_limit(
@@ -866,7 +811,6 @@ async def stream_chat_post(
is_user_message=request.is_user_message,
context=request.context,
file_ids=sanitized_file_ids,
mode=request.mode,
)
setup_time = (time.perf_counter() - stream_start_time) * 1000
@@ -1230,7 +1174,7 @@ async def health_check() -> dict:
)
# Create and retrieve session to verify full data layer
session = await create_chat_session(health_check_user_id, dry_run=False)
session = await create_chat_session(health_check_user_id)
await get_chat_session(session.session_id, health_check_user_id)
return {

View File

@@ -9,7 +9,6 @@ import pytest
import pytest_mock
from backend.api.features.chat import routes as chat_routes
from backend.copilot.rate_limit import SubscriptionTier
app = fastapi.FastAPI()
app.include_router(chat_routes.router)
@@ -332,28 +331,14 @@ def _mock_usage(
*,
daily_used: int = 500,
weekly_used: int = 2000,
daily_limit: int = 10000,
weekly_limit: int = 50000,
tier: "SubscriptionTier" = SubscriptionTier.FREE,
) -> AsyncMock:
"""Mock get_usage_status and get_global_rate_limits for usage endpoint tests.
Mocks both ``get_global_rate_limits`` (returns the given limits + tier) and
``get_usage_status`` so that tests exercise the endpoint without hitting
LaunchDarkly or Prisma.
"""
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
mocker.patch(
"backend.api.features.chat.routes.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(daily_limit, weekly_limit, tier),
)
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at),
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
@@ -384,7 +369,6 @@ def test_usage_returns_daily_and_weekly(
daily_token_limit=10000,
weekly_token_limit=50000,
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
tier=SubscriptionTier.FREE,
)
@@ -392,9 +376,11 @@ def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards resolved limits from get_global_rate_limits to get_usage_status."""
mock_get = _mock_usage(mocker, daily_limit=99999, weekly_limit=77777)
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
response = client.get("/usage")
@@ -405,7 +391,6 @@ def test_usage_uses_config_limits(
daily_token_limit=99999,
weekly_token_limit=77777,
rate_limit_reset_cost=500,
tier=SubscriptionTier.FREE,
)
@@ -484,98 +469,3 @@ def test_suggested_prompts_empty_prompts(
assert response.status_code == 200
assert response.json() == {"themes": []}
# ─── Create session: dry_run contract ─────────────────────────────────
def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
"""Mock create_chat_session to return a fake session."""
from backend.copilot.model import ChatSession
async def _fake_create(user_id: str, *, dry_run: bool):
return ChatSession.new(user_id, dry_run=dry_run)
return mocker.patch(
"backend.api.features.chat.routes.create_chat_session",
new_callable=AsyncMock,
side_effect=_fake_create,
)
def test_create_session_dry_run_true(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
_mock_create_chat_session(mocker)
response = client.post("/sessions", json={"dry_run": True})
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is True
def test_create_session_dry_run_default_false(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Empty body defaults dry_run to False."""
_mock_create_chat_session(mocker)
response = client.post("/sessions")
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is False
def test_create_session_rejects_nested_metadata(
test_user_id: str,
) -> None:
"""Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
default to ``dry_run=False``. This guards against the common mistake of
nesting dry_run inside metadata instead of providing it at the top level."""
response = client.post(
"/sessions",
json={"metadata": {"dry_run": True}},
)
assert response.status_code == 422
class TestStreamChatRequestModeValidation:
"""Pydantic-level validation of the ``mode`` field on StreamChatRequest."""
def test_rejects_invalid_mode_value(self) -> None:
"""Any string outside the Literal set must raise ValidationError."""
from pydantic import ValidationError
from backend.api.features.chat.routes import StreamChatRequest
with pytest.raises(ValidationError):
StreamChatRequest(message="hi", mode="turbo") # type: ignore[arg-type]
def test_accepts_fast_mode(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode="fast")
assert req.mode == "fast"
def test_accepts_extended_thinking_mode(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode="extended_thinking")
assert req.mode == "extended_thinking"
def test_accepts_none_mode(self) -> None:
"""``mode=None`` is valid (server decides via feature flags)."""
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode=None)
assert req.mode is None
def test_mode_defaults_to_none_when_omitted(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi")
assert req.mode is None

View File

@@ -40,15 +40,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.managed_credentials import ensure_managed_credentials
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -114,7 +110,6 @@ class CredentialsMetaResponse(BaseModel):
default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
)
is_managed: bool = False
@model_validator(mode="before")
@classmethod
@@ -153,7 +148,6 @@ def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
is_managed=cred.is_managed,
)
@@ -230,9 +224,6 @@ async def callback(
async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
# Fire-and-forget: provision missing managed credentials in the background.
# The credential appears on the next page load; listing is never blocked.
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_all_creds(user_id)
return [
@@ -247,7 +238,6 @@ async def list_credentials_by_provider(
],
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
@@ -342,11 +332,6 @@ async def delete_credentials(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
if not creds:
raise HTTPException(
@@ -357,11 +342,6 @@ async def delete_credentials(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials not found",
)
if creds.is_managed:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="AutoGPT-managed credentials cannot be deleted",
)
try:
await remove_all_webhooks_for_credentials(user_id, creds, force)

View File

@@ -1,7 +1,6 @@
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import fastapi
import fastapi.testclient
@@ -277,294 +276,3 @@ class TestCreateCredentialNoSecretInResponse:
assert resp.status_code == 403
mock_mgr.create.assert_not_called()
class TestManagedCredentials:
"""AutoGPT-managed credentials cannot be deleted by users."""
def test_delete_is_managed_returns_403(self):
cred = APIKeyCredentials(
id="managed-cred-1",
provider="agent_mail",
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-managed-key"),
is_managed=True,
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
assert resp.status_code == 403
assert "AutoGPT-managed" in resp.json()["detail"]
def test_list_credentials_includes_is_managed_field(self):
managed = APIKeyCredentials(
id="managed-1",
provider="agent_mail",
title="AgentMail (managed)",
api_key=SecretStr("sk-key"),
is_managed=True,
)
regular = APIKeyCredentials(
id="regular-1",
provider="openai",
title="My Key",
api_key=SecretStr("sk-key"),
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
resp = client.get("/credentials")
assert resp.status_code == 200
data = resp.json()
managed_cred = next(c for c in data if c["id"] == "managed-1")
regular_cred = next(c for c in data if c["id"] == "regular-1")
assert managed_cred["is_managed"] is True
assert regular_cred["is_managed"] is False
# ---------------------------------------------------------------------------
# Managed credential provisioning infrastructure
# ---------------------------------------------------------------------------
def _make_managed_cred(
provider: str = "agent_mail", pod_id: str = "pod-abc"
) -> APIKeyCredentials:
return APIKeyCredentials(
id="managed-auto",
provider=provider,
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-pod-key"),
is_managed=True,
metadata={"pod_id": pod_id},
)
def _make_store_mock(**kwargs) -> MagicMock:
"""Create a store mock with a working async ``locks()`` context manager."""
@asynccontextmanager
async def _noop_locked(key):
yield
locks_obj = MagicMock()
locks_obj.locked = _noop_locked
store = MagicMock(**kwargs)
store.locks = AsyncMock(return_value=locks_obj)
return store
class TestEnsureManagedCredentials:
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
@pytest.mark.asyncio
async def test_provisions_when_missing(self):
"""Provider.provision() is called when no managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(return_value=cred)
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
store.add_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_awaited_once_with("user-1")
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_when_already_exists(self):
"""Provider.provision() is NOT called when managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=True)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_unavailable(self):
"""Provider.provision() is NOT called when provider is not available."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=False)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
store.has_managed_credential.assert_not_awaited()
@pytest.mark.asyncio
async def test_provision_failure_does_not_propagate(self):
"""A failed provision is logged but does not raise."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
# No exception raised — provisioning failure is swallowed.
class TestCleanupManagedCredentials:
"""Unit tests for cleanup_managed_credentials."""
@pytest.mark.asyncio
async def test_calls_deprovision_for_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_non_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
regular = _make_api_key_cred()
provider = MagicMock()
provider.provider_name = "openai"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[regular])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["openai"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_not_awaited()
@pytest.mark.asyncio
async def test_deprovision_failure_does_not_propagate(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
# No exception raised — cleanup failure is swallowed.

View File

@@ -481,11 +481,6 @@ async def create_library_agent(
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
},
},
include=library_agent_include(

View File

@@ -12,7 +12,6 @@ Tests cover:
5. Complete OAuth flow end-to-end
"""
import asyncio
import base64
import hashlib
import secrets
@@ -59,27 +58,14 @@ async def test_user(server, test_user_id: str):
yield test_user_id
# Cleanup - delete in correct order due to foreign key constraints.
# Wrap in try/except because the event loop or Prisma engine may already
# be closed during session teardown on Python 3.12+.
try:
await asyncio.gather(
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
PrismaOAuthRefreshToken.prisma().delete_many(
where={"userId": test_user_id}
),
PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
),
)
await asyncio.gather(
PrismaOAuthApplication.prisma().delete_many(
where={"ownerId": test_user_id}
),
PrismaUser.prisma().delete(where={"id": test_user_id}),
)
except RuntimeError:
pass
# Cleanup - delete in correct order due to foreign key constraints
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
)
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
await PrismaUser.prisma().delete(where={"id": test_user_id})
@pytest_asyncio.fixture

View File

@@ -1,61 +0,0 @@
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
from backend.api.features.v1 import v1_router
app = fastapi.FastAPI()
app.include_router(v1_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_onboarding_profile_success(mocker):
mock_extract = mocker.patch(
"backend.api.features.v1.extract_business_understanding",
new_callable=AsyncMock,
)
mock_upsert = mocker.patch(
"backend.api.features.v1.upsert_business_understanding",
new_callable=AsyncMock,
)
from backend.data.understanding import BusinessUnderstandingInput
mock_extract.return_value = BusinessUnderstandingInput.model_construct(
user_name="John",
user_role="Founder/CEO",
pain_points=["Finding leads"],
suggested_prompts={"Learn": ["How do I automate lead gen?"]},
)
mock_upsert.return_value = AsyncMock()
response = client.post(
"/onboarding/profile",
json={
"user_name": "John",
"user_role": "Founder/CEO",
"pain_points": ["Finding leads", "Email & outreach"],
},
)
assert response.status_code == 200
mock_extract.assert_awaited_once()
mock_upsert.assert_awaited_once()
def test_onboarding_profile_missing_fields():
response = client.post(
"/onboarding/profile",
json={"user_name": "John"},
)
assert response.status_code == 422

View File

@@ -189,7 +189,6 @@ async def test_create_store_submission(mocker):
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="Europe/Delft",
subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
)
mock_agent = prisma.models.AgentGraph(
id="agent-id",

View File

@@ -63,17 +63,12 @@ from backend.data.onboarding import (
UserOnboardingUpdate,
complete_onboarding_step,
complete_re_run_agent,
format_onboarding_for_extraction,
get_recommended_agents,
get_user_onboarding,
onboarding_enabled,
reset_user_onboarding,
update_user_onboarding,
)
from backend.data.tally import extract_business_understanding
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)
from backend.data.user import (
get_or_create_user,
get_user_by_id,
@@ -287,33 +282,35 @@ async def get_onboarding_agents(
return await get_recommended_agents(user_id)
class OnboardingProfileRequest(pydantic.BaseModel):
"""Request body for onboarding profile submission."""
user_name: str = pydantic.Field(min_length=1, max_length=100)
user_role: str = pydantic.Field(min_length=1, max_length=100)
pain_points: list[str] = pydantic.Field(default_factory=list, max_length=20)
class OnboardingStatusResponse(pydantic.BaseModel):
"""Response for onboarding completion check."""
"""Response for onboarding status check."""
is_completed: bool
is_onboarding_enabled: bool
is_chat_enabled: bool
@v1_router.get(
"/onboarding/completed",
summary="Check if onboarding is completed",
"/onboarding/enabled",
summary="Is onboarding enabled",
tags=["onboarding", "public"],
response_model=OnboardingStatusResponse,
dependencies=[Security(requires_user)],
)
async def is_onboarding_completed(
async def is_onboarding_enabled(
user_id: Annotated[str, Security(get_user_id)],
) -> OnboardingStatusResponse:
user_onboarding = await get_user_onboarding(user_id)
# Check if chat is enabled for user
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
# If chat is enabled, skip legacy onboarding
if is_chat_enabled:
return OnboardingStatusResponse(
is_onboarding_enabled=False,
is_chat_enabled=True,
)
return OnboardingStatusResponse(
is_completed=OnboardingStep.VISIT_COPILOT in user_onboarding.completedSteps,
is_onboarding_enabled=await onboarding_enabled(),
is_chat_enabled=False,
)
@@ -328,38 +325,6 @@ async def reset_onboarding(user_id: Annotated[str, Security(get_user_id)]):
return await reset_user_onboarding(user_id)
@v1_router.post(
"/onboarding/profile",
summary="Submit onboarding profile",
tags=["onboarding"],
dependencies=[Security(requires_user)],
)
async def submit_onboarding_profile(
data: OnboardingProfileRequest,
user_id: Annotated[str, Security(get_user_id)],
):
formatted = format_onboarding_for_extraction(
user_name=data.user_name,
user_role=data.user_role,
pain_points=data.pain_points,
)
try:
understanding_input = await extract_business_understanding(formatted)
except Exception:
understanding_input = BusinessUnderstandingInput.model_construct()
# Ensure the direct fields are set even if LLM missed them
understanding_input.user_name = data.user_name
understanding_input.user_role = data.user_role
if not understanding_input.pain_points:
understanding_input.pain_points = data.pain_points
await upsert_business_understanding(user_id, understanding_input)
return {"status": "ok"}
########################################################
##################### Blocks ###########################
########################################################

View File

@@ -12,7 +12,7 @@ import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi import Query, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel, Field
from pydantic import BaseModel
from backend.data.workspace import (
WorkspaceFile,
@@ -131,26 +131,9 @@ class StorageUsageResponse(BaseModel):
file_count: int
class WorkspaceFileItem(BaseModel):
id: str
name: str
path: str
mime_type: str
size_bytes: int
metadata: dict = Field(default_factory=dict)
created_at: str
class ListFilesResponse(BaseModel):
files: list[WorkspaceFileItem]
offset: int = 0
has_more: bool = False
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
operation_id="getWorkspaceDownloadFileById",
)
async def download_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -175,7 +158,6 @@ async def download_file(
@router.delete(
"/files/{file_id}",
summary="Delete a workspace file",
operation_id="deleteWorkspaceFile",
)
async def delete_workspace_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -201,7 +183,6 @@ async def delete_workspace_file(
@router.post(
"/files/upload",
summary="Upload file to workspace",
operation_id="uploadWorkspaceFile",
)
async def upload_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -215,9 +196,6 @@ async def upload_file(
Files are stored in session-scoped paths when session_id is provided,
so the agent's session-scoped tools can discover them automatically.
"""
# Empty-string session_id drops session scoping; normalize to None.
session_id = session_id or None
config = Config()
# Sanitize filename — strip any directory components
@@ -272,27 +250,16 @@ async def upload_file(
manager = WorkspaceManager(user_id, workspace.id, session_id)
try:
workspace_file = await manager.write_file(
content, filename, overwrite=overwrite, metadata={"origin": "user-upload"}
content, filename, overwrite=overwrite
)
except ValueError as e:
# write_file raises ValueError for both path-conflict and size-limit
# cases; map each to its correct HTTP status.
message = str(e)
if message.startswith("File too large"):
raise fastapi.HTTPException(status_code=413, detail=message) from e
raise fastapi.HTTPException(status_code=409, detail=message) from e
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
# Post-write storage check — eliminates TOCTOU race on the quota.
# If a concurrent upload pushed us over the limit, undo this write.
new_total = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and new_total > storage_limit_bytes:
try:
await soft_delete_workspace_file(workspace_file.id, workspace.id)
except Exception as e:
logger.warning(
f"Failed to soft-delete over-quota file {workspace_file.id} "
f"in workspace {workspace.id}: {e}"
)
await soft_delete_workspace_file(workspace_file.id, workspace.id)
raise fastapi.HTTPException(
status_code=413,
detail={
@@ -314,7 +281,6 @@ async def upload_file(
@router.get(
"/storage/usage",
summary="Get workspace storage usage",
operation_id="getWorkspaceStorageUsage",
)
async def get_storage_usage(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -335,57 +301,3 @@ async def get_storage_usage(
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
file_count=file_count,
)
@router.get(
"/files",
summary="List workspace files",
operation_id="listWorkspaceFiles",
)
async def list_workspace_files(
user_id: Annotated[str, fastapi.Security(get_user_id)],
session_id: str | None = Query(default=None),
limit: int = Query(default=200, ge=1, le=1000),
offset: int = Query(default=0, ge=0),
) -> ListFilesResponse:
"""
List files in the user's workspace.
When session_id is provided, only files for that session are returned.
Otherwise, all files across sessions are listed. Results are paginated
via `limit`/`offset`; `has_more` indicates whether additional pages exist.
"""
workspace = await get_or_create_workspace(user_id)
# Treat empty-string session_id the same as omitted — an empty value
# would otherwise silently list files across every session instead of
# scoping to one.
session_id = session_id or None
manager = WorkspaceManager(user_id, workspace.id, session_id)
include_all = session_id is None
# Fetch one extra to compute has_more without a separate count query.
files = await manager.list_files(
limit=limit + 1,
offset=offset,
include_all_sessions=include_all,
)
has_more = len(files) > limit
page = files[:limit]
return ListFilesResponse(
files=[
WorkspaceFileItem(
id=f.id,
name=f.name,
path=f.path,
mime_type=f.mime_type,
size_bytes=f.size_bytes,
metadata=f.metadata or {},
created_at=f.created_at.isoformat(),
)
for f in page
],
offset=offset,
has_more=has_more,
)

View File

@@ -1,28 +1,48 @@
"""Tests for workspace file upload and download routes."""
import io
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.workspace.routes import router
from backend.data.workspace import Workspace, WorkspaceFile
from backend.api.features.workspace import routes as workspace_routes
from backend.data.workspace import WorkspaceFile
app = fastapi.FastAPI()
app.include_router(router)
app.include_router(workspace_routes.router)
@app.exception_handler(ValueError)
async def _value_error_handler(
request: fastapi.Request, exc: ValueError
) -> fastapi.responses.JSONResponse:
"""Mirror the production ValueError → 400 mapping from the REST app."""
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-1",
created_at=_NOW,
updated_at=_NOW,
name="hello.txt",
path="/session/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
@@ -33,201 +53,25 @@ def setup_app_auth(mock_jwt_user):
app.dependency_overrides.clear()
def _make_workspace(user_id: str = "test-user-id") -> Workspace:
return Workspace(
id="ws-001",
user_id=user_id,
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
)
def _make_file(**overrides) -> WorkspaceFile:
defaults = {
"id": "file-001",
"workspace_id": "ws-001",
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"name": "test.txt",
"path": "/test.txt",
"storage_path": "local://test.txt",
"mime_type": "text/plain",
"size_bytes": 100,
"checksum": None,
"is_deleted": False,
"deleted_at": None,
"metadata": {},
}
defaults.update(overrides)
return WorkspaceFile(**defaults)
def _make_file_mock(**overrides) -> MagicMock:
"""Create a mock WorkspaceFile to simulate DB records with null fields."""
defaults = {
"id": "file-001",
"name": "test.txt",
"path": "/test.txt",
"mime_type": "text/plain",
"size_bytes": 100,
"metadata": {},
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
}
defaults.update(overrides)
mock = MagicMock(spec=WorkspaceFile)
for k, v in defaults.items():
setattr(mock, k, v)
return mock
# -- list_workspace_files tests --
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_returns_all_when_no_session(mock_manager_cls, mock_get_workspace):
mock_get_workspace.return_value = _make_workspace()
files = [
_make_file(id="f1", name="a.txt", metadata={"origin": "user-upload"}),
_make_file(id="f2", name="b.csv", metadata={"origin": "agent-created"}),
]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files")
assert response.status_code == 200
data = response.json()
assert len(data["files"]) == 2
assert data["has_more"] is False
assert data["offset"] == 0
assert data["files"][0]["id"] == "f1"
assert data["files"][0]["metadata"] == {"origin": "user-upload"}
assert data["files"][1]["id"] == "f2"
mock_instance.list_files.assert_called_once_with(
limit=201, offset=0, include_all_sessions=True
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_scopes_to_session_when_provided(
mock_manager_cls, mock_get_workspace, test_user_id
):
mock_get_workspace.return_value = _make_workspace(user_id=test_user_id)
mock_instance = AsyncMock()
mock_instance.list_files.return_value = []
mock_manager_cls.return_value = mock_instance
response = client.get("/files?session_id=sess-123")
assert response.status_code == 200
data = response.json()
assert data["files"] == []
assert data["has_more"] is False
mock_manager_cls.assert_called_once_with(test_user_id, "ws-001", "sess-123")
mock_instance.list_files.assert_called_once_with(
limit=201, offset=0, include_all_sessions=False
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_null_metadata_coerced_to_empty_dict(
mock_manager_cls, mock_get_workspace
):
"""Route uses `f.metadata or {}` for pre-existing files with null metadata."""
mock_get_workspace.return_value = _make_workspace()
mock_instance = AsyncMock()
mock_instance.list_files.return_value = [_make_file_mock(metadata=None)]
mock_manager_cls.return_value = mock_instance
response = client.get("/files")
assert response.status_code == 200
assert response.json()["files"][0]["metadata"] == {}
# -- upload_file metadata tests --
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
@patch("backend.api.features.workspace.routes.scan_content_safe")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_upload_passes_user_upload_origin_metadata(
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
):
mock_get_workspace.return_value = _make_workspace()
mock_total_size.return_value = 100
written = _make_file(id="new-file", name="doc.pdf")
mock_instance = AsyncMock()
mock_instance.write_file.return_value = written
mock_manager_cls.return_value = mock_instance
response = client.post(
"/files/upload",
files={"file": ("doc.pdf", b"fake-pdf-content", "application/pdf")},
)
assert response.status_code == 200
mock_instance.write_file.assert_called_once()
call_kwargs = mock_instance.write_file.call_args
assert call_kwargs.kwargs.get("metadata") == {"origin": "user-upload"}
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
@patch("backend.api.features.workspace.routes.scan_content_safe")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_upload_returns_409_on_file_conflict(
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
):
mock_get_workspace.return_value = _make_workspace()
mock_total_size.return_value = 100
mock_instance = AsyncMock()
mock_instance.write_file.side_effect = ValueError("File already exists at path")
mock_manager_cls.return_value = mock_instance
response = client.post(
"/files/upload",
files={"file": ("dup.txt", b"content", "text/plain")},
)
assert response.status_code == 409
assert "already exists" in response.json()["detail"]
# -- Restored upload/download/delete security + invariant tests --
def _upload(
filename: str = "hello.txt",
content: bytes = b"Hello, world!",
content_type: str = "text/plain",
):
"""Helper to POST a file upload."""
return client.post(
"/files/upload?session_id=sess-1",
files={"file": (filename, io.BytesIO(content), content_type)},
)
_MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-001",
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
name="hello.txt",
path="/sessions/sess-1/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
# ---- Happy path ----
def test_upload_happy_path(mocker):
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -238,7 +82,7 @@ def test_upload_happy_path(mocker):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -252,7 +96,10 @@ def test_upload_happy_path(mocker):
assert data["size_bytes"] == 13
def test_upload_exceeds_max_file_size(mocker):
# ---- Per-file size limit ----
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
"""Files larger than max_file_size_mb should be rejected with 413."""
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
@@ -262,11 +109,15 @@ def test_upload_exceeds_max_file_size(mocker):
assert response.status_code == 413
def test_upload_storage_quota_exceeded(mocker):
# ---- Storage quota exceeded ----
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
# Current usage already at limit
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=500 * 1024 * 1024,
@@ -277,22 +128,27 @@ def test_upload_storage_quota_exceeded(mocker):
assert "Storage limit exceeded" in response.text
def test_upload_post_write_quota_race(mocker):
"""Concurrent upload tipping over limit after write should soft-delete + 413."""
# ---- Post-write quota race (B2) ----
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
"""If a concurrent upload tips the total over the limit after write,
the file should be soft-deleted and 413 returned."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
# Pre-write check passes (under limit), but post-write check fails
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
side_effect=[0, 600 * 1024 * 1024],
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -304,14 +160,17 @@ def test_upload_post_write_quota_race(mocker):
response = _upload()
assert response.status_code == 413
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-001")
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
def test_upload_any_extension(mocker):
# ---- Any extension accepted (no allowlist) ----
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
"""Any file extension should be accepted — ClamAV is the security layer."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -322,7 +181,7 @@ def test_upload_any_extension(mocker):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -332,13 +191,16 @@ def test_upload_any_extension(mocker):
assert response.status_code == 200
def test_upload_blocked_by_virus_scan(mocker):
# ---- Virus scan rejection ----
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
"""Files flagged by ClamAV should be rejected and never written to storage."""
from backend.api.features.store.exceptions import VirusDetectedError
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -349,7 +211,7 @@ def test_upload_blocked_by_virus_scan(mocker):
side_effect=VirusDetectedError("Eicar-Test-Signature"),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -357,14 +219,18 @@ def test_upload_blocked_by_virus_scan(mocker):
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
assert response.status_code == 400
assert "Virus detected" in response.text
mock_manager.write_file.assert_not_called()
def test_upload_file_without_extension(mocker):
# ---- No file extension ----
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
"""Files without an extension should be accepted and stored as-is."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -375,7 +241,7 @@ def test_upload_file_without_extension(mocker):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -391,11 +257,14 @@ def test_upload_file_without_extension(mocker):
assert mock_manager.write_file.call_args[0][1] == "Makefile"
def test_upload_strips_path_components(mocker):
# ---- Filename sanitization (SF5) ----
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
"""Path-traversal filenames should be reduced to their basename."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -406,23 +275,28 @@ def test_upload_strips_path_components(mocker):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
# Filename with traversal
_upload(filename="../../etc/passwd.txt")
# write_file should have been called with just the basename
mock_manager.write_file.assert_called_once()
call_args = mock_manager.write_file.call_args
assert call_args[0][1] == "passwd.txt"
def test_download_file_not_found(mocker):
# ---- Download ----
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_file",
@@ -433,11 +307,14 @@ def test_download_file_not_found(mocker):
assert response.status_code == 404
def test_delete_file_success(mocker):
# ---- Delete ----
def test_delete_file_success(mocker: pytest_mock.MockFixture):
"""Deleting an existing file should return {"deleted": true}."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
@@ -452,11 +329,11 @@ def test_delete_file_success(mocker):
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
def test_delete_file_not_found(mocker):
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
"""Deleting a non-existent file should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=_make_workspace(),
return_value=MOCK_WORKSPACE,
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
@@ -470,7 +347,7 @@ def test_delete_file_not_found(mocker):
assert "File not found" in response.text
def test_delete_file_no_workspace(mocker):
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
"""Deleting when user has no workspace should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
@@ -480,123 +357,3 @@ def test_delete_file_no_workspace(mocker):
response = client.delete("/files/file-aaa-bbb")
assert response.status_code == 404
assert "Workspace not found" in response.text
def test_upload_write_file_too_large_returns_413(mocker):
"""write_file raises ValueError("File too large: …") → must map to 413."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit")
)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 413
assert "File too large" in response.text
def test_upload_write_file_conflict_returns_409(mocker):
"""Non-'File too large' ValueErrors from write_file stay as 409."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=ValueError("File already exists at path: /sessions/x/a.txt")
)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 409
assert "already exists" in response.text
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_has_more_true_when_limit_exceeded(
mock_manager_cls, mock_get_workspace
):
"""The limit+1 fetch trick must flip has_more=True and trim the page."""
mock_get_workspace.return_value = _make_workspace()
# Backend was asked for limit+1=3, and returned exactly 3 items.
files = [
_make_file(id="f1", name="a.txt"),
_make_file(id="f2", name="b.txt"),
_make_file(id="f3", name="c.txt"),
]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files?limit=2")
assert response.status_code == 200
data = response.json()
assert data["has_more"] is True
assert len(data["files"]) == 2
assert data["files"][0]["id"] == "f1"
assert data["files"][1]["id"] == "f2"
mock_instance.list_files.assert_called_once_with(
limit=3, offset=0, include_all_sessions=True
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_has_more_false_when_exactly_page_size(
mock_manager_cls, mock_get_workspace
):
"""Exactly `limit` rows means we're on the last page — has_more=False."""
mock_get_workspace.return_value = _make_workspace()
files = [_make_file(id="f1", name="a.txt"), _make_file(id="f2", name="b.txt")]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files?limit=2")
assert response.status_code == 200
data = response.json()
assert data["has_more"] is False
assert len(data["files"]) == 2
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
mock_get_workspace.return_value = _make_workspace()
mock_instance = AsyncMock()
mock_instance.list_files.return_value = []
mock_manager_cls.return_value = mock_instance
response = client.get("/files?offset=50&limit=10")
assert response.status_code == 200
assert response.json()["offset"] == 50
mock_instance.list_files.assert_called_once_with(
limit=11, offset=50, include_all_sessions=True
)

View File

@@ -118,11 +118,6 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Register managed credential providers (e.g. AgentMail)
from backend.integrations.managed_providers import register_all
register_all()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()

View File

@@ -698,30 +698,20 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
if should_pause:
return
# Validate the input data (original or reviewer-modified) once.
# In dry-run mode, credential fields may contain sentinel None values
# that would fail JSON schema required checks. We still validate the
# non-credential fields so blocks that execute for real during dry-run
# (e.g. AgentExecutorBlock) get proper input validation.
is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
if is_dry_run:
cred_field_names = set(self.input_schema.get_credentials_fields().keys())
non_cred_data = {
k: v for k, v in input_data.items() if k not in cred_field_names
}
if error := self.input_schema.validate_data(non_cred_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
else:
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Ensure auto_credentials kwargs are present (even as None) so blocks
# that declare e.g. `*, credentials: GoogleCredentials` as a required
# kwarg don't crash with TypeError when the execution context didn't
# resolve credentials (e.g. the GoogleDriveFile field was empty).
for kwarg_name in self.input_schema.get_auto_credentials_fields():
kwargs.setdefault(kwarg_name, None)
# Use the validated input data
async for output_name, output_data in self.run(

View File

@@ -49,17 +49,11 @@ class AgentExecutorBlock(Block):
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
required_fields = cls.get_input_schema(data).get("required", [])
# Check against the nested `inputs` dict, not the top-level node
# data — required fields like "topic" live inside data["inputs"],
# not at data["topic"].
provided = data.get("inputs", {})
return set(required_fields) - set(provided)
return set(required_fields) - set(data)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return validate_with_jsonschema(
cls.get_input_schema(data), data.get("inputs", {})
)
return validate_with_jsonschema(cls.get_input_schema(data), data)
class Output(BlockSchema):
# Use BlockSchema to avoid automatic error field that could clash with graph outputs
@@ -94,7 +88,6 @@ class AgentExecutorBlock(Block):
execution_context=execution_context.model_copy(
update={"parent_execution_id": graph_exec_id},
),
dry_run=execution_context.dry_run,
)
logger = execution_utils.LogMetadata(
@@ -156,19 +149,14 @@ class AgentExecutorBlock(Block):
ExecutionStatus.TERMINATED,
ExecutionStatus.FAILED,
]:
logger.info(
f"Execution {log_id} skipping event {event.event_type} status={event.status} "
f"node={getattr(event, 'node_exec_id', '?')}"
logger.debug(
f"Execution {log_id} received event {event.event_type} with status {event.status}"
)
continue
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
# we can stop listening for further events.
logger.info(
f"Execution {log_id} graph completed with status {event.status}, "
f"yielded {len(yielded_node_exec_ids)} outputs"
)
self.merge_stats(
NodeExecutionStats(
extra_cost=event.stats.cost if event.stats else 0,

View File

@@ -146,21 +146,6 @@ class AutoPilotBlock(Block):
advanced=True,
)
dry_run: bool = SchemaField(
description=(
"When enabled, run_block and run_agent tool calls in this "
"autopilot session are forced to use dry-run simulation mode. "
"No real API calls, side effects, or credits are consumed "
"by those tools. Useful for testing agent wiring and "
"previewing outputs. "
"Only applies when creating a new session (session_id is empty). "
"When reusing an existing session_id, the session's original "
"dry_run setting is preserved."
),
default=False,
advanced=True,
)
# timeout_seconds removed: the SDK manages its own heartbeat-based
# timeouts internally; wrapping with asyncio.timeout corrupts the
# SDK's internal stream (see service.py CRITICAL comment).
@@ -247,11 +232,11 @@ class AutoPilotBlock(Block):
},
)
async def create_session(self, user_id: str, *, dry_run: bool) -> str:
async def create_session(self, user_id: str) -> str:
"""Create a new chat session and return its ID (mockable for tests)."""
from backend.copilot.model import create_chat_session # avoid circular import
session = await create_chat_session(user_id, dry_run=dry_run)
session = await create_chat_session(user_id)
return session.session_id
async def execute_copilot(
@@ -382,9 +367,7 @@ class AutoPilotBlock(Block):
# even if the downstream stream fails (avoids orphaned sessions).
sid = input_data.session_id
if not sid:
sid = await self.create_session(
execution_context.user_id, dry_run=input_data.dry_run
)
sid = await self.create_session(execution_context.user_id)
# NOTE: No asyncio.timeout() here — the SDK manages its own
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout

View File

@@ -1,6 +1,5 @@
import asyncio
import base64
import re
from abc import ABC
from email import encoders
from email.mime.base import MIMEBase
@@ -9,7 +8,7 @@ from email.mime.text import MIMEText
from email.policy import SMTP
from email.utils import getaddresses, parseaddr
from pathlib import Path
from typing import List, Literal, Optional, Protocol, runtime_checkable
from typing import List, Literal, Optional
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
@@ -43,52 +42,8 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
def serialize_email_recipients(recipients: list[str]) -> str:
"""Serialize recipients list to comma-separated string.
Strips leading/trailing whitespace from each address to keep MIME
headers clean (mirrors the strip done in ``validate_email_recipients``).
"""
return ", ".join(addr.strip() for addr in recipients)
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
"""Validate that all recipients are plausible email addresses.
Raises ``ValueError`` with a user-friendly message listing every
invalid entry so the caller (or LLM) can correct them in one pass.
"""
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
if invalid:
formatted = ", ".join(f"'{a}'" for a in invalid)
raise ValueError(
f"Invalid email address(es) in '{field_name}': {formatted}. "
f"Each entry must be a valid email address (e.g. user@example.com)."
)
@runtime_checkable
class HasRecipients(Protocol):
to: list[str]
cc: list[str]
bcc: list[str]
def validate_all_recipients(input_data: HasRecipients) -> None:
"""Validate to/cc/bcc recipient fields on an input namespace.
Calls ``validate_email_recipients`` for ``to`` (required) and
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
first field that contains an invalid address.
"""
validate_email_recipients(input_data.to, "to")
if input_data.cc:
validate_email_recipients(input_data.cc, "cc")
if input_data.bcc:
validate_email_recipients(input_data.bcc, "bcc")
"""Serialize recipients list to comma-separated string."""
return ", ".join(recipients)
def _make_mime_text(
@@ -145,16 +100,14 @@ async def create_mime_message(
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
validate_all_recipients(input_data)
message = MIMEMultipart()
message["to"] = serialize_email_recipients(input_data.to)
message["subject"] = input_data.subject
if input_data.cc:
message["cc"] = serialize_email_recipients(input_data.cc)
message["cc"] = ", ".join(input_data.cc)
if input_data.bcc:
message["bcc"] = serialize_email_recipients(input_data.bcc)
message["bcc"] = ", ".join(input_data.bcc)
# Use the new helper function with content_type if available
content_type = getattr(input_data, "content_type", None)
@@ -1214,15 +1167,13 @@ async def _build_reply_message(
references.append(headers["message-id"])
# Create MIME message
validate_all_recipients(input_data)
msg = MIMEMultipart()
if input_data.to:
msg["To"] = serialize_email_recipients(input_data.to)
msg["To"] = ", ".join(input_data.to)
if input_data.cc:
msg["Cc"] = serialize_email_recipients(input_data.cc)
msg["Cc"] = ", ".join(input_data.cc)
if input_data.bcc:
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Subject"] = subject
if headers.get("message-id"):
msg["In-Reply-To"] = headers["message-id"]
@@ -1734,16 +1685,13 @@ To: {original_to}
else:
body = f"{forward_header}\n\n{original_body}"
# Validate all recipient lists before building the MIME message
validate_all_recipients(input_data)
# Create MIME message
msg = MIMEMultipart()
msg["To"] = serialize_email_recipients(input_data.to)
msg["To"] = ", ".join(input_data.to)
if input_data.cc:
msg["Cc"] = serialize_email_recipients(input_data.cc)
msg["Cc"] = ", ".join(input_data.cc)
if input_data.bcc:
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Subject"] = subject
# Add body with proper content type

View File

@@ -326,11 +326,18 @@ class GoogleSheetsReadBlock(Block):
)
async def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
self,
input_data: Input,
*,
credentials: GoogleCredentials | None = None,
**kwargs,
) -> BlockOutput:
if not input_data.spreadsheet:
yield "error", "No spreadsheet selected"
return
if not credentials:
yield "error", "Google credentials are required"
return
# Check if the selected file is actually a Google Sheets spreadsheet
validation_error = _validate_spreadsheet_file(input_data.spreadsheet)

View File

@@ -0,0 +1,18 @@
"""Tests for Google Sheets blocks — edge cases around credentials handling."""
import pytest
from backend.blocks.google.sheets import GoogleSheetsReadBlock
from backend.util.exceptions import BlockExecutionError
async def test_sheets_read_no_spreadsheet_yields_clean_error():
"""Executing GoogleSheetsReadBlock with no spreadsheet/credentials should
raise a clean BlockExecutionError, not a BlockUnknownError wrapping a
TypeError about a missing kwarg."""
block = GoogleSheetsReadBlock()
input_data = {"range": "Sheet1!A1:B2"} # no spreadsheet, no credentials
with pytest.raises(BlockExecutionError, match="No spreadsheet selected"):
async for _ in block.execute(input_data):
pass

View File

@@ -2,8 +2,6 @@ import copy
from datetime import date, time
from typing import Any, Optional
from pydantic import AliasChoices, Field
from backend.blocks._base import (
Block,
BlockCategory,
@@ -469,8 +467,7 @@ class AgentFileInputBlock(AgentInputBlock):
class AgentDropdownInputBlock(AgentInputBlock):
"""
A specialized text input block that presents a dropdown selector
restricted to a fixed set of values.
A specialized text input block that relies on placeholder_values to present a dropdown.
"""
class Input(AgentInputBlock.Input):
@@ -480,23 +477,16 @@ class AgentDropdownInputBlock(AgentInputBlock):
advanced=False,
title="Default Value",
)
# Use Field() directly (not SchemaField) to pass validation_alias,
# which handles backward compat for legacy "placeholder_values" across
# all construction paths (model_construct, __init__, model_validate).
options: list = Field(
placeholder_values: list = SchemaField(
description="Possible values for the dropdown.",
default_factory=list,
advanced=False,
title="Dropdown Options",
description=(
"If provided, renders the input as a dropdown selector "
"restricted to these values. Leave empty for free-text input."
),
validation_alias=AliasChoices("options", "placeholder_values"),
json_schema_extra={"advanced": False, "secret": False},
)
def generate_schema(self):
schema = super().generate_schema()
if possible_values := self.options:
if possible_values := self.placeholder_values:
schema["enum"] = possible_values
return schema
@@ -514,13 +504,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
{
"value": "Option A",
"name": "dropdown_1",
"options": ["Option A", "Option B", "Option C"],
"placeholder_values": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 1",
},
{
"value": "Option C",
"name": "dropdown_2",
"options": ["Option A", "Option B", "Option C"],
"placeholder_values": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 2",
},
],

View File

@@ -205,19 +205,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
QWEN3_CODER = "qwen/qwen3-coder"
# Z.ai (Zhipu) models
ZAI_GLM_4_32B = "z-ai/glm-4-32b"
ZAI_GLM_4_5 = "z-ai/glm-4.5"
ZAI_GLM_4_5_AIR = "z-ai/glm-4.5-air"
ZAI_GLM_4_5_AIR_FREE = "z-ai/glm-4.5-air:free"
ZAI_GLM_4_5V = "z-ai/glm-4.5v"
ZAI_GLM_4_6 = "z-ai/glm-4.6"
ZAI_GLM_4_6V = "z-ai/glm-4.6v"
ZAI_GLM_4_7 = "z-ai/glm-4.7"
ZAI_GLM_4_7_FLASH = "z-ai/glm-4.7-flash"
ZAI_GLM_5 = "z-ai/glm-5"
ZAI_GLM_5_TURBO = "z-ai/glm-5-turbo"
ZAI_GLM_5V_TURBO = "z-ai/glm-5v-turbo"
# Llama API models
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
@@ -643,43 +630,6 @@ MODEL_METADATA = {
LlmModel.QWEN3_CODER: ModelMetadata(
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
),
# https://openrouter.ai/models?q=z-ai
LlmModel.ZAI_GLM_4_32B: ModelMetadata(
"open_router", 128000, 128000, "GLM 4 32B", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_5: ModelMetadata(
"open_router", 131072, 98304, "GLM 4.5", "OpenRouter", "Z.ai", 2
),
LlmModel.ZAI_GLM_4_5_AIR: ModelMetadata(
"open_router", 131072, 98304, "GLM 4.5 Air", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_5_AIR_FREE: ModelMetadata(
"open_router", 131072, 96000, "GLM 4.5 Air (Free)", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_5V: ModelMetadata(
"open_router", 65536, 16384, "GLM 4.5V", "OpenRouter", "Z.ai", 2
),
LlmModel.ZAI_GLM_4_6: ModelMetadata(
"open_router", 204800, 204800, "GLM 4.6", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_6V: ModelMetadata(
"open_router", 131072, 131072, "GLM 4.6V", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_7: ModelMetadata(
"open_router", 202752, 65535, "GLM 4.7", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_7_FLASH: ModelMetadata(
"open_router", 202752, 202752, "GLM 4.7 Flash", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_5: ModelMetadata(
"open_router", 80000, 80000, "GLM 5", "OpenRouter", "Z.ai", 2
),
LlmModel.ZAI_GLM_5_TURBO: ModelMetadata(
"open_router", 202752, 131072, "GLM 5 Turbo", "OpenRouter", "Z.ai", 3
),
LlmModel.ZAI_GLM_5V_TURBO: ModelMetadata(
"open_router", 202752, 131072, "GLM 5V Turbo", "OpenRouter", "Z.ai", 3
),
# Llama API models
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
"llama_api",
@@ -774,9 +724,6 @@ def convert_openai_tool_fmt_to_anthropic(
def extract_openai_reasoning(response) -> str | None:
"""Extract reasoning from OpenAI-compatible response if available."""
"""Note: This will likely not working since the reasoning is not present in another Response API"""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_reasoning")
return None
reasoning = None
choice = response.choices[0]
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
@@ -792,9 +739,6 @@ def extract_openai_reasoning(response) -> str | None:
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
"""Extract tool calls from OpenAI-compatible response."""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
return None
if response.choices[0].message.tool_calls:
return [
ToolContentBlock(
@@ -1028,8 +972,6 @@ async def llm_call(
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
if not response.choices:
raise ValueError("Groq returned empty choices in response")
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
@@ -1089,8 +1031,12 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
raise ValueError(f"OpenRouter returned empty choices: {response}")
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1127,8 +1073,12 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
raise ValueError(f"Llama API returned empty choices: {response}")
if response:
raise ValueError(f"Llama API error: {response}")
else:
raise ValueError("No response from Llama API.")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1158,8 +1108,6 @@ async def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
if not completion.choices:
raise ValueError("AI/ML API returned empty choices in response")
return LLMResponse(
raw_response=completion.choices[0].message,
@@ -1196,9 +1144,6 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
if not response.choices:
raise ValueError(f"v0 API returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -2066,19 +2011,6 @@ class AIConversationBlock(AIBlockBase):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
has_messages = any(
isinstance(m, dict)
and isinstance(m.get("content"), str)
and bool(m["content"].strip())
for m in (input_data.messages or [])
)
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
if not has_messages and not has_prompt:
raise ValueError(
"Cannot call LLM with no messages and no prompt. "
"Provide at least one message or a non-empty prompt."
)
response = await self.llm_call(
AIStructuredResponseGeneratorBlock.Input(
prompt=input_data.prompt,

View File

@@ -89,12 +89,6 @@ class MCPToolBlock(Block):
default={},
hidden=True,
)
tool_description: str = SchemaField(
description="Description of the selected MCP tool. "
"Populated automatically when a tool is selected.",
default="",
hidden=True,
)
tool_arguments: dict[str, Any] = SchemaField(
description="Arguments to pass to the selected MCP tool. "

File diff suppressed because it is too large Load Diff

View File

@@ -1,323 +0,0 @@
import asyncio
from typing import Any, Literal
from pydantic import SecretStr
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import DBAPIError, OperationalError, ProgrammingError
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.sql_query_helpers import (
_DATABASE_TYPE_DEFAULT_PORT,
_DATABASE_TYPE_TO_DRIVER,
DatabaseType,
_execute_query,
_sanitize_error,
_validate_query_is_read_only,
_validate_single_statement,
)
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
SchemaField,
UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName
from backend.util.request import resolve_and_check_blocked
TEST_CREDENTIALS = UserPasswordCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="database",
username=SecretStr("test_user"),
password=SecretStr("test_pass"),
title="Mock Database credentials",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
DatabaseCredentials = UserPasswordCredentials
DatabaseCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.DATABASE],
Literal["user_password"],
]
def DatabaseCredentialsField() -> DatabaseCredentialsInput:
return CredentialsField(
description="Database username and password",
)
class SQLQueryBlock(Block):
class Input(BlockSchemaInput):
database_type: DatabaseType = SchemaField(
default=DatabaseType.POSTGRES,
description="Database engine",
advanced=False,
)
host: SecretStr = SchemaField(
description=(
"Database hostname or IP address. "
"Treated as a secret to avoid leaking infrastructure details. "
"Private/internal IPs are blocked (SSRF protection)."
),
placeholder="db.example.com",
secret=True,
)
port: int | None = SchemaField(
default=None,
description=(
"Database port (leave empty for default: "
"PostgreSQL: 5432, MySQL: 3306, MSSQL: 1433)"
),
ge=1,
le=65535,
)
database: str = SchemaField(
description="Name of the database to connect to",
placeholder="my_database",
)
query: str = SchemaField(
description="SQL query to execute",
placeholder="SELECT * FROM analytics.daily_active_users LIMIT 10",
)
read_only: bool = SchemaField(
default=True,
description=(
"When enabled (default), only SELECT queries are allowed "
"and the database session is set to read-only mode. "
"Disable to allow write operations (INSERT, UPDATE, DELETE, etc.)."
),
)
timeout: int = SchemaField(
default=30,
description="Query timeout in seconds (max 120)",
ge=1,
le=120,
)
max_rows: int = SchemaField(
default=1000,
description="Maximum number of rows to return (max 10000)",
ge=1,
le=10000,
)
credentials: DatabaseCredentialsInput = DatabaseCredentialsField()
class Output(BlockSchemaOutput):
results: list[dict[str, Any]] = SchemaField(
description="Query results as a list of row dictionaries"
)
columns: list[str] = SchemaField(
description="Column names from the query result"
)
row_count: int = SchemaField(description="Number of rows returned")
truncated: bool = SchemaField(
description=(
"True when the result set was capped by max_rows, "
"indicating additional rows exist in the database"
)
)
affected_rows: int = SchemaField(
description="Number of rows affected by a write query (INSERT/UPDATE/DELETE)"
)
error: str = SchemaField(description="Error message if the query failed")
def __init__(self):
super().__init__(
id="4dc35c0f-4fd8-465e-9616-5a216f1ba2bc",
description=(
"Execute a SQL query. Read-only by default for safety "
"-- disable to allow write operations. "
"Supports PostgreSQL, MySQL, and MSSQL via SQLAlchemy."
),
categories={BlockCategory.DATA},
input_schema=SQLQueryBlock.Input,
output_schema=SQLQueryBlock.Output,
test_input={
"query": "SELECT 1 AS test_col",
"database_type": DatabaseType.POSTGRES,
"host": "localhost",
"database": "test_db",
"timeout": 30,
"max_rows": 1000,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("results", [{"test_col": 1}]),
("columns", ["test_col"]),
("row_count", 1),
("truncated", False),
],
test_mock={
"execute_query": lambda *_args, **_kwargs: (
[{"test_col": 1}],
["test_col"],
-1,
False,
),
"check_host_allowed": lambda *_args, **_kwargs: ["127.0.0.1"],
},
)
@staticmethod
async def check_host_allowed(host: str) -> list[str]:
"""Validate that the given host is not a private/blocked address.
Returns the list of resolved IP addresses so the caller can pin the
connection to the validated IP (preventing DNS rebinding / TOCTOU).
Raises ValueError or OSError if the host is blocked.
Extracted as a method so it can be mocked during block tests.
"""
return await resolve_and_check_blocked(host)
@staticmethod
def execute_query(
connection_url: URL | str,
query: str,
timeout: int,
max_rows: int,
read_only: bool = True,
database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
Delegates to ``_execute_query`` in ``sql_query_helpers``.
Extracted as a method so it can be mocked during block tests.
"""
return _execute_query(
connection_url=connection_url,
query=query,
timeout=timeout,
max_rows=max_rows,
read_only=read_only,
database_type=database_type,
)
async def run(
self,
input_data: Input,
*,
credentials: DatabaseCredentials,
**_kwargs: Any,
) -> BlockOutput:
# Validate query structure and read-only constraints.
error = self._validate_query(input_data)
if error:
yield "error", error
return
# Validate host and resolve for SSRF protection.
host, pinned_host, error = await self._resolve_host(input_data)
if error:
yield "error", error
return
# Build connection URL and execute.
port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
username = credentials.username.get_secret_value()
connection_url = URL.create(
drivername=_DATABASE_TYPE_TO_DRIVER[input_data.database_type],
username=username,
password=credentials.password.get_secret_value(),
host=pinned_host,
port=port,
database=input_data.database,
)
conn_str = connection_url.render_as_string(hide_password=True)
db_name = input_data.database
def _sanitize(err: Exception) -> str:
return _sanitize_error(
str(err).strip(),
conn_str,
host=pinned_host,
original_host=host,
username=username,
port=port,
database=db_name,
)
try:
results, columns, affected, truncated = await asyncio.to_thread(
self.execute_query,
connection_url=connection_url,
query=input_data.query,
timeout=input_data.timeout,
max_rows=input_data.max_rows,
read_only=input_data.read_only,
database_type=input_data.database_type,
)
yield "results", results
yield "columns", columns
yield "row_count", len(results)
yield "truncated", truncated
if affected >= 0:
yield "affected_rows", affected
except OperationalError as e:
yield (
"error",
self._classify_operational_error(
_sanitize(e),
input_data.timeout,
),
)
except ProgrammingError as e:
yield "error", f"SQL error: {_sanitize(e)}"
except DBAPIError as e:
yield "error", f"Database error: {_sanitize(e)}"
except ModuleNotFoundError:
yield (
"error",
(
f"Database driver not available for "
f"{input_data.database_type.value}. "
f"Please contact the platform administrator."
),
)
@staticmethod
def _validate_query(input_data: "SQLQueryBlock.Input") -> str | None:
"""Validate query structure and read-only constraints."""
stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
if stmt_error:
return stmt_error
assert parsed_stmt is not None
if input_data.read_only:
return _validate_query_is_read_only(parsed_stmt)
return None
async def _resolve_host(
self, input_data: "SQLQueryBlock.Input"
) -> tuple[str, str, str | None]:
"""Validate and resolve the database host. Returns (host, pinned_ip, error)."""
host = input_data.host.get_secret_value().strip()
if not host:
return "", "", "Database host is required."
if host.startswith("/"):
return host, "", "Unix socket connections are not allowed."
try:
resolved_ips = await self.check_host_allowed(host)
except (ValueError, OSError) as e:
return host, "", f"Blocked host: {str(e).strip()}"
return host, resolved_ips[0], None
@staticmethod
def _classify_operational_error(sanitized_msg: str, timeout: int) -> str:
"""Classify an already-sanitized OperationalError for user display."""
lower = sanitized_msg.lower()
if "timeout" in lower or "cancel" in lower:
return f"Query timed out after {timeout}s."
if "connect" in lower:
return f"Failed to connect to database: {sanitized_msg}"
return f"Database error: {sanitized_msg}"

File diff suppressed because it is too large Load Diff

View File

@@ -1,430 +0,0 @@
import re
from datetime import date, datetime, time
from decimal import Decimal
from enum import Enum
from typing import Any
import sqlparse
from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import URL
class DatabaseType(str, Enum):
POSTGRES = "postgres"
MYSQL = "mysql"
MSSQL = "mssql"
# Defense-in-depth: reject queries containing data-modifying keywords.
# These are checked against parsed SQL tokens (not raw text) so column names
# and string literals do not cause false positives.
_DISALLOWED_KEYWORDS = {
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"ALTER",
"CREATE",
"TRUNCATE",
"GRANT",
"REVOKE",
"COPY",
"EXECUTE",
"CALL",
"SET",
"RESET",
"DISCARD",
"NOTIFY",
"DO",
# MySQL file exfiltration: LOAD DATA LOCAL INFILE reads server/client files
"LOAD",
# MySQL REPLACE is INSERT-or-UPDATE; data modification
"REPLACE",
# ANSI MERGE (UPSERT) modifies data
"MERGE",
# MSSQL BULK INSERT loads external files into tables
"BULK",
# MSSQL EXEC / EXEC sp_name runs stored procedures (arbitrary code)
"EXEC",
}
# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
_DATABASE_TYPE_TO_DRIVER = {
DatabaseType.POSTGRES: "postgresql",
DatabaseType.MYSQL: "mysql+pymysql",
DatabaseType.MSSQL: "mssql+pymssql",
}
# Connection timeout in seconds passed to the DBAPI driver (connect_timeout /
# login_timeout). This bounds how long the driver waits to establish a TCP
# connection to the database server. It is separate from the per-statement
# timeout configured via SET commands inside _configure_session().
_CONNECT_TIMEOUT_SECONDS = 10
# Default ports for each database type.
_DATABASE_TYPE_DEFAULT_PORT = {
DatabaseType.POSTGRES: 5432,
DatabaseType.MYSQL: 3306,
DatabaseType.MSSQL: 1433,
}
def _sanitize_error(
error_msg: str,
connection_string: str,
*,
host: str = "",
original_host: str = "",
username: str = "",
port: int = 0,
database: str = "",
) -> str:
"""Remove connection string, credentials, and infrastructure details
from error messages so they are safe to expose to the LLM.
Scrubs:
- The full connection string
- URL-embedded credentials (``://user:pass@``)
- ``password=<value>`` key-value pairs
- The database hostname / IP used for the connection
- The original (pre-resolution) hostname provided by the user
- Any IPv4 addresses that appear in the message
- Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
- The database username
- The database port number
- The database name
"""
sanitized = error_msg.replace(connection_string, "<connection_string>")
sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)
# Replace the known host (may be an IP already) before the generic IP pass.
# Also replace the original (pre-DNS-resolution) hostname if it differs.
if original_host and original_host != host:
sanitized = sanitized.replace(original_host, "<host>")
if host:
sanitized = sanitized.replace(host, "<host>")
# Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)
# Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)
# Replace the database username (handles double-quoted, single-quoted,
# and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
if username:
sanitized = re.sub(
r"""for user ["']?""" + re.escape(username) + r"""["']?""",
"for user <user>",
sanitized,
)
# Catch remaining bare occurrences in various quote styles:
# - PostgreSQL: "FATAL: role "myuser" does not exist"
# - MySQL: "Access denied for user 'myuser'@'host'"
# - MSSQL: "Login failed for user 'myuser'"
sanitized = sanitized.replace(f'"{username}"', "<user>")
sanitized = sanitized.replace(f"'{username}'", "<user>")
# Replace the port number (handles "port 5432" and ":5432" formats)
if port:
port_str = re.escape(str(port))
sanitized = re.sub(
r"(?:port |:)" + port_str + r"(?![0-9])",
lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
sanitized,
)
# Replace the database name to avoid leaking internal infrastructure names.
# Use word-boundary regex to prevent mangling when the database name is a
# common substring (e.g. "test", "data", "on").
if database:
sanitized = re.sub(r"\b" + re.escape(database) + r"\b", "<database>", sanitized)
return sanitized
def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
"""Extract keyword tokens from a parsed SQL statement.
Uses sqlparse token type classification to collect Keyword/DML/DDL/DCL
tokens. String literals and identifiers have different token types, so
they are naturally excluded from the result.
"""
return [
token.normalized.upper()
for token in parsed.flatten()
if token.ttype
in (
sqlparse.tokens.Keyword,
sqlparse.tokens.Keyword.DML,
sqlparse.tokens.Keyword.DDL,
sqlparse.tokens.Keyword.DCL,
)
]
def _has_disallowed_into(stmt: sqlparse.sql.Statement) -> bool:
"""Check if a statement contains a disallowed ``INTO`` clause.
``SELECT ... INTO @variable`` is a valid read-only MySQL syntax that stores
a query result into a session-scoped user variable. All other forms of
``INTO`` are data-modifying or file-writing and must be blocked:
* ``SELECT ... INTO new_table`` (PostgreSQL / MSSQL creates a table)
* ``SELECT ... INTO OUTFILE`` (MySQL writes to the filesystem)
* ``SELECT ... INTO DUMPFILE`` (MySQL writes to the filesystem)
* ``INSERT INTO ...`` (already blocked by INSERT being in the
disallowed set, but we reject INTO as well for defense-in-depth)
Returns ``True`` if the statement contains a disallowed ``INTO``.
"""
flat = list(stmt.flatten())
for i, token in enumerate(flat):
if not (
token.ttype in (sqlparse.tokens.Keyword,)
and token.normalized.upper() == "INTO"
):
continue
# Look at the first non-whitespace token after INTO.
j = i + 1
while j < len(flat) and flat[j].ttype is sqlparse.tokens.Text.Whitespace:
j += 1
if j >= len(flat):
# INTO at the very end malformed, block it.
return True
next_token = flat[j]
# MySQL user variable: either a single Name starting with "@"
# (e.g. ``@total``) or a bare ``@`` Operator token followed by a Name.
if next_token.ttype is sqlparse.tokens.Name and next_token.value.startswith(
"@"
):
continue
if next_token.ttype is sqlparse.tokens.Operator and next_token.value == "@":
continue
# Everything else (table name, OUTFILE, DUMPFILE, etc.) is disallowed.
return True
return False
def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
"""Validate that a parsed SQL statement is read-only (SELECT/WITH only).
Accepts an already-parsed statement from ``_validate_single_statement``
to avoid re-parsing. Checks:
1. Statement type must be SELECT (sqlparse classifies WITH...SELECT as SELECT)
2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, etc.)
3. No disallowed INTO clauses (allows MySQL ``SELECT ... INTO @variable``)
Returns an error message if the query is not read-only, None otherwise.
"""
# sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
if stmt.get_type() != "SELECT":
return "Only SELECT queries are allowed."
# Defense-in-depth: check parsed keyword tokens for disallowed keywords
for kw in _extract_keyword_tokens(stmt):
# Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
base_kw = kw.split()[0] if " " in kw else kw
if base_kw in _DISALLOWED_KEYWORDS:
return f"Disallowed SQL keyword: {kw}"
# Contextual check for INTO: allow MySQL @variable syntax, block everything else
if _has_disallowed_into(stmt):
return "Disallowed SQL keyword: INTO"
return None
def _validate_single_statement(
query: str,
) -> tuple[str | None, sqlparse.sql.Statement | None]:
"""Validate that the query contains exactly one non-empty SQL statement.
Returns (error_message, parsed_statement). If error_message is not None,
the query is invalid and parsed_statement will be None.
"""
stripped = query.strip().rstrip(";").strip()
if not stripped:
return "Query is empty.", None
# Parse the SQL using sqlparse for proper tokenization
statements = sqlparse.parse(stripped)
# Filter out empty statements and comment-only statements
statements = [
s
for s in statements
if s.tokens
and str(s).strip()
and not all(
t.is_whitespace or t.ttype in sqlparse.tokens.Comment for t in s.flatten()
)
]
if not statements:
return "Query is empty.", None
# Reject multiple statements -- prevents injection via semicolons
if len(statements) > 1:
return "Only single statements are allowed.", None
return None, statements[0]
def _serialize_value(value: Any) -> Any:
"""Convert database-specific types to JSON-serializable Python types."""
if isinstance(value, Decimal):
# NaN / Infinity are not valid JSON numbers; serialize as strings.
if value.is_nan() or value.is_infinite():
return str(value)
# Use int for whole numbers; use str for fractional to preserve exact
# precision (float would silently round high-precision analytics values).
if value == value.to_integral_value():
return int(value)
return str(value)
if isinstance(value, (datetime, date, time)):
return value.isoformat()
if isinstance(value, memoryview):
return bytes(value).hex()
if isinstance(value, bytes):
return value.hex()
return value
def _configure_session(
conn: Any,
dialect_name: str,
timeout_ms: str,
read_only: bool,
) -> None:
"""Set session-level timeout and read-only mode for the given dialect.
Timeout limitations by database:
* **PostgreSQL** ``statement_timeout`` reliably cancels any running
statement (SELECT or DML) after the configured duration.
* **MySQL** ``MAX_EXECUTION_TIME`` only applies to **read-only SELECT**
statements. DML (INSERT/UPDATE/DELETE) and DDL are *not* bounded by
this hint; they rely on the server's ``wait_timeout`` /
``interactive_timeout`` instead. There is no session-level setting in
MySQL that reliably cancels long-running writes.
* **MSSQL** ``SET LOCK_TIMEOUT`` only limits how long the server waits
to acquire a **lock**. CPU-bound queries (e.g. large scans, hash
joins) that do not block on locks will *not* be cancelled. MSSQL has
no session-level ``statement_timeout`` equivalent; the closest
mechanism is Resource Governor (requires sysadmin configuration) or
``CONTEXT_INFO``-based external monitoring.
Note: SQLite is not supported by this block. The ``_configure_session``
function is a no-op for unrecognised dialect names, so an SQLite engine
would skip all SET commands silently. The block's ``DatabaseType`` enum
intentionally excludes SQLite.
"""
if dialect_name == "postgresql":
conn.execute(text("SET statement_timeout = " + timeout_ms))
if read_only:
conn.execute(text("SET default_transaction_read_only = ON"))
elif dialect_name == "mysql":
# NOTE: MAX_EXECUTION_TIME only applies to SELECT statements.
# Write queries (INSERT/UPDATE/DELETE) are not bounded by this
# setting; they rely on the database's wait_timeout instead.
# See docstring above for full limitations.
conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
if read_only:
conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
elif dialect_name == "mssql":
# MSSQL: SET LOCK_TIMEOUT limits lock-wait time (ms) only.
# CPU-bound queries without lock contention are NOT cancelled.
# See docstring above for full limitations.
conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
# MSSQL lacks a session-level read-only mode like
# PostgreSQL/MySQL. Read-only enforcement is handled by
# the SQL validation layer (_validate_query_is_read_only)
# and the ROLLBACK in the finally block.
def _run_in_transaction(
conn: Any,
dialect_name: str,
query: str,
max_rows: int,
read_only: bool,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a query inside an explicit transaction, returning results.
Returns ``(rows, columns, affected_rows, truncated)`` where *truncated*
is ``True`` when ``fetchmany`` returned exactly ``max_rows`` rows,
indicating that additional rows may exist in the result set.
"""
# MSSQL uses T-SQL "BEGIN TRANSACTION"; others use "BEGIN".
begin_stmt = "BEGIN TRANSACTION" if dialect_name == "mssql" else "BEGIN"
conn.execute(text(begin_stmt))
try:
result = conn.execute(text(query))
affected = result.rowcount if not result.returns_rows else -1
columns = list(result.keys()) if result.returns_rows else []
rows = result.fetchmany(max_rows) if result.returns_rows else []
truncated = len(rows) == max_rows
results = [
{col: _serialize_value(val) for col, val in zip(columns, row)}
for row in rows
]
except Exception:
try:
conn.execute(text("ROLLBACK"))
except Exception:
pass
raise
else:
conn.execute(text("ROLLBACK" if read_only else "COMMIT"))
return results, columns, affected, truncated
def _execute_query(
connection_url: URL | str,
query: str,
timeout: int,
max_rows: int,
read_only: bool = True,
database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
Uses SQLAlchemy to connect to any supported database.
For SELECT queries, rows are limited to ``max_rows`` via DBAPI fetchmany.
``truncated`` is ``True`` when the result set was capped by ``max_rows``.
For write queries, affected_rows contains the rowcount from the driver.
When ``read_only`` is True, the database session is set to read-only
mode and the transaction is always rolled back.
"""
# Determine driver-specific connection timeout argument.
# pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
timeout_key = (
"login_timeout" if database_type == DatabaseType.MSSQL else "connect_timeout"
)
engine = create_engine(
connection_url, connect_args={timeout_key: _CONNECT_TIMEOUT_SECONDS}
)
try:
with engine.connect() as conn:
# Use AUTOCOMMIT so SET commands take effect immediately.
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
# Compute timeout in milliseconds. The value is Pydantic-validated
# (ge=1, le=120), but we use int() as defense-in-depth.
# NOTE: SET commands do not support bind parameters in most
# databases, so we use str(int(...)) for safe interpolation.
timeout_ms = str(int(timeout * 1000))
_configure_session(conn, engine.dialect.name, timeout_ms, read_only)
return _run_in_transaction(
conn, engine.dialect.name, query, max_rows, read_only
)
finally:
engine.dispose()

View File

@@ -300,27 +300,13 @@ def test_agent_input_block_ignores_legacy_placeholder_values():
def test_dropdown_input_block_produces_enum():
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum
using the canonical 'options' field name."""
opts = ["Option A", "Option B"]
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
options = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, options=opts
name="choice", value=None, placeholder_values=options
)
schema = instance.generate_schema()
assert schema.get("enum") == opts
def test_dropdown_input_block_legacy_placeholder_values_produces_enum():
"""Verify backward compat: passing legacy 'placeholder_values' to
AgentDropdownInputBlock still produces enum via model_construct remap."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, placeholder_values=opts
)
schema = instance.generate_schema()
assert (
schema.get("enum") == opts
), "Legacy placeholder_values should be remapped to options"
assert schema.get("enum") == options
def test_generate_schema_integration_legacy_placeholder_values():
@@ -343,11 +329,11 @@ def test_generate_schema_integration_legacy_placeholder_values():
def test_generate_schema_integration_dropdown_produces_enum():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
— verifies enum IS produced for dropdown blocks using canonical field name."""
— verifies enum IS produced for dropdown blocks."""
dropdown_input_default = {
"name": "color",
"value": None,
"options": ["Red", "Green", "Blue"],
"placeholder_values": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, dropdown_input_default),
@@ -358,36 +344,3 @@ def test_generate_schema_integration_dropdown_produces_enum():
"Green",
"Blue",
], "Graph schema should contain enum from AgentDropdownInputBlock"
def test_generate_schema_integration_dropdown_legacy_placeholder_values():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
using legacy 'placeholder_values' — verifies backward compat produces enum."""
legacy_dropdown_input_default = {
"name": "color",
"value": None,
"placeholder_values": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, legacy_dropdown_input_default),
)
color_props = result["properties"]["color"]
assert color_props.get("enum") == [
"Red",
"Green",
"Blue",
], "Legacy placeholder_values should still produce enum via model_construct remap"
def test_dropdown_input_block_init_legacy_placeholder_values():
"""Verify backward compat: constructing AgentDropdownInputBlock.Input via
model_validate with legacy 'placeholder_values' correctly maps to 'options'."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_validate(
{"name": "choice", "value": None, "placeholder_values": opts}
)
assert (
instance.options == opts
), "Legacy placeholder_values should be remapped to options via model_validate"
schema = instance.generate_schema()
assert schema.get("enum") == opts

View File

@@ -488,154 +488,6 @@ class TestLLMStatsTracking:
assert outputs["response"] == {"result": "test"}
class TestAIConversationBlockValidation:
"""Test that AIConversationBlock validates inputs before calling the LLM."""
@pytest.mark.asyncio
async def test_empty_messages_and_empty_prompt_raises_error(self):
"""Empty messages with no prompt should raise ValueError, not a cryptic API error."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_empty_messages_with_prompt_succeeds(self):
"""Empty messages but a non-empty prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "OK"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="Hello, how are you?",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "OK"
@pytest.mark.asyncio
async def test_nonempty_messages_with_empty_prompt_succeeds(self):
"""Non-empty messages with no prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "response from conversation"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": "Hello"}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "response from conversation"
@pytest.mark.asyncio
async def test_messages_with_empty_content_raises_error(self):
"""Messages with empty content strings should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": ""}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_whitespace_content_raises_error(self):
"""Messages with whitespace-only content should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": " "}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_entry_raises_error(self):
"""Messages list containing None should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[None],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_empty_dict_raises_error(self):
"""Messages list containing empty dict should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_content_raises_error(self):
"""Messages with content=None should not crash with AttributeError."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": None}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
class TestAITextSummarizerValidation:
"""Test that AITextSummarizerBlock validates LLM responses are strings."""

View File

@@ -1,87 +0,0 @@
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
from unittest.mock import MagicMock
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
class TestExtractOpenaiToolCallsEmptyChoices:
"""extract_openai_tool_calls() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_tool_calls(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_tool_calls(response) is None
def test_returns_tool_calls_when_choices_present(self):
tool = MagicMock()
tool.id = "call_1"
tool.type = "function"
tool.function.name = "my_func"
tool.function.arguments = '{"a": 1}'
message = MagicMock()
message.tool_calls = [tool]
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
result = extract_openai_tool_calls(response)
assert result is not None
assert len(result) == 1
assert result[0].function.name == "my_func"
def test_returns_none_when_no_tool_calls(self):
message = MagicMock()
message.tool_calls = None
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
assert extract_openai_tool_calls(response) is None
class TestExtractOpenaiReasoningEmptyChoices:
"""extract_openai_reasoning() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_reasoning(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_reasoning(response) is None
def test_returns_reasoning_from_choice(self):
choice = MagicMock()
choice.reasoning = "Step-by-step reasoning"
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
response = MagicMock(spec=[]) # no 'reasoning' attr on response
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result == "Step-by-step reasoning"
def test_returns_none_when_no_reasoning(self):
choice = MagicMock(spec=[]) # no 'reasoning' attr
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
response = MagicMock(spec=[]) # no 'reasoning' attr
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result is None

View File

@@ -1074,7 +1074,6 @@ async def test_orchestrator_uses_customized_name_for_blocks():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)
@@ -1106,7 +1105,6 @@ async def test_orchestrator_falls_back_to_block_name():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {} # No customized_name
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)

View File

@@ -1,202 +0,0 @@
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
Covers:
- ExecutionMode enum members exist and have stable values
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
- EXTENDED_THINKING model-name validation (must start with "claude")
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.llm import LlmModel
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
# ---------------------------------------------------------------------------
# ExecutionMode enum integrity
# ---------------------------------------------------------------------------
class TestExecutionModeEnum:
"""Guard against accidental renames or removals of enum members."""
def test_built_in_exists(self):
assert hasattr(ExecutionMode, "BUILT_IN")
assert ExecutionMode.BUILT_IN.value == "built_in"
def test_extended_thinking_exists(self):
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
def test_exactly_two_members(self):
"""If a new mode is added, this test should be updated intentionally."""
assert set(ExecutionMode.__members__.keys()) == {
"BUILT_IN",
"EXTENDED_THINKING",
}
def test_string_enum(self):
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
assert isinstance(ExecutionMode.BUILT_IN, str)
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
def test_round_trip_from_value(self):
"""Constructing from the string value should return the same member."""
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
# ---------------------------------------------------------------------------
# Provider validation (inline in OrchestratorBlock.run)
# ---------------------------------------------------------------------------
def _make_model_stub(provider: str, value: str):
"""Create a lightweight stub that behaves like LlmModel for validation."""
metadata = MagicMock()
metadata.provider = provider
stub = MagicMock()
stub.metadata = metadata
stub.value = value
return stub
class TestExtendedThinkingProviderValidation:
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
def test_anthropic_provider_accepted(self):
"""provider='anthropic' + claude model should not raise."""
model = _make_model_stub("anthropic", "claude-opus-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_open_router_provider_accepted(self):
"""provider='open_router' + claude model should not raise."""
model = _make_model_stub("open_router", "claude-sonnet-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_openai_provider_rejected(self):
"""provider='openai' should be rejected for EXTENDED_THINKING."""
model = _make_model_stub("openai", "gpt-4o")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_groq_provider_rejected(self):
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
model = _make_model_stub("anthropic", "not-a-claude-model")
model_name = model.value
assert not model_name.startswith("claude")
def test_real_gpt4o_model_rejected(self):
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
model = LlmModel.GPT4O
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_real_claude_model_passes(self):
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
model = LlmModel.CLAUDE_4_6_SONNET
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
# ---------------------------------------------------------------------------
# Integration-style: exercise the validation branch via OrchestratorBlock.run
# ---------------------------------------------------------------------------
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
inp = MagicMock()
inp.execution_mode = execution_mode
inp.model = model
inp.prompt = "test"
inp.sys_prompt = ""
inp.conversation_history = []
inp.last_tool_output = None
inp.prompt_values = {}
return inp
async def _collect_run_outputs(block, input_data, **kwargs):
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
outputs = []
async for item in block.run(input_data, **kwargs):
outputs.append(item)
return outputs
class TestExtendedThinkingValidationRaisesInBlock:
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
@pytest.mark.asyncio
async def test_non_anthropic_provider_raises_valueerror(self):
"""EXTENDED_THINKING + openai provider raises ValueError."""
block = OrchestratorBlock()
input_data = _make_input_data(model=LlmModel.GPT4O)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="Anthropic-compatible"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)
@pytest.mark.asyncio
async def test_non_claude_model_with_anthropic_provider_raises(self):
"""A model with anthropic provider but non-claude name raises ValueError."""
block = OrchestratorBlock()
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
input_data = _make_input_data(model=fake_model)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="only supports Claude models"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)

File diff suppressed because it is too large Load Diff

View File

@@ -31,7 +31,7 @@ async def test_baseline_multi_turn(setup_test_user, test_user_id):
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id, dry_run=False)
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---

View File

@@ -1,633 +0,0 @@
"""Unit tests for baseline service pure-logic helpers.
These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState``
without requiring API keys, database connections, or network access.
"""
from unittest.mock import AsyncMock, patch
import pytest
from openai.types.chat import ChatCompletionToolParam
from backend.copilot.baseline.service import (
_baseline_conversation_updater,
_BaselineStreamState,
_compress_session_messages,
_ThinkingStripper,
)
from backend.copilot.model import ChatMessage
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util.prompt import CompressResult
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
class TestBaselineStreamState:
def test_defaults(self):
state = _BaselineStreamState()
assert state.pending_events == []
assert state.assistant_text == ""
assert state.text_started is False
assert state.turn_prompt_tokens == 0
assert state.turn_completion_tokens == 0
assert state.text_block_id # Should be a UUID string
def test_mutable_fields(self):
state = _BaselineStreamState()
state.assistant_text = "hello"
state.turn_prompt_tokens = 100
state.turn_completion_tokens = 50
assert state.assistant_text == "hello"
assert state.turn_prompt_tokens == 100
assert state.turn_completion_tokens == 50
class TestBaselineConversationUpdater:
"""Tests for _baseline_conversation_updater which updates the OpenAI
message list and transcript builder after each LLM call."""
def _make_transcript_builder(self) -> TranscriptBuilder:
builder = TranscriptBuilder()
builder.append_user("test question")
return builder
def test_text_only_response(self):
"""When the LLM returns text without tool calls, the updater appends
a single assistant message and records it in the transcript."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text="Hello, world!",
tool_calls=[],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
_baseline_conversation_updater(
messages,
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert len(messages) == 1
assert messages[0]["role"] == "assistant"
assert messages[0]["content"] == "Hello, world!"
# Transcript should have user + assistant
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
def test_tool_calls_response(self):
"""When the LLM returns tool calls, the updater appends the assistant
message with tool_calls and tool result messages."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text="Let me search...",
tool_calls=[
LLMToolCall(
id="tc_1",
name="search",
arguments='{"query": "test"}',
),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(
tool_call_id="tc_1",
tool_name="search",
content="Found result",
),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
# Messages: assistant (with tool_calls) + tool result
assert len(messages) == 2
assert messages[0]["role"] == "assistant"
assert messages[0]["content"] == "Let me search..."
assert len(messages[0]["tool_calls"]) == 1
assert messages[0]["tool_calls"][0]["id"] == "tc_1"
assert messages[1]["role"] == "tool"
assert messages[1]["tool_call_id"] == "tc_1"
assert messages[1]["content"] == "Found result"
# Transcript: user + assistant(tool_use) + user(tool_result)
assert builder.entry_count == 3
def test_tool_calls_without_text(self):
"""Tool calls without accompanying text should still work."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="tc_1", name="run", arguments="{}"),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(tool_call_id="tc_1", tool_name="run", content="done"),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
assert len(messages) == 2
assert "content" not in messages[0] # No text content
assert messages[0]["tool_calls"][0]["function"]["name"] == "run"
def test_no_text_no_tools(self):
"""When the response has no text and no tool calls, nothing is appended."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
_baseline_conversation_updater(
messages,
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert len(messages) == 0
# Only the user entry from setup
assert builder.entry_count == 1
def test_multiple_tool_calls(self):
"""Multiple tool calls in a single response are all recorded."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="tc_1", name="tool_a", arguments="{}"),
LLMToolCall(id="tc_2", name="tool_b", arguments='{"x": 1}'),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(tool_call_id="tc_1", tool_name="tool_a", content="result_a"),
ToolCallResult(tool_call_id="tc_2", tool_name="tool_b", content="result_b"),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
# 1 assistant + 2 tool results
assert len(messages) == 3
assert len(messages[0]["tool_calls"]) == 2
assert messages[1]["tool_call_id"] == "tc_1"
assert messages[2]["tool_call_id"] == "tc_2"
def test_invalid_tool_arguments_handled(self):
"""Tool call with invalid JSON arguments: the arguments field is
stored as-is in the message, and orjson failure falls back to {}
in the transcript content_blocks."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="tc_1", name="tool_x", arguments="not-json"),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(tool_call_id="tc_1", tool_name="tool_x", content="ok"),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
# Should not raise — invalid JSON falls back to {} in transcript
assert len(messages) == 2
assert messages[0]["tool_calls"][0]["function"]["arguments"] == "not-json"
class TestCompressSessionMessagesPreservesToolCalls:
"""``_compress_session_messages`` must round-trip tool_calls + tool_call_id.
Compression serialises ChatMessage to dict for ``compress_context`` and
reifies the result back to ChatMessage. A regression that drops
``tool_calls`` or ``tool_call_id`` would corrupt the OpenAI message
list and break downstream tool-execution rounds.
"""
@pytest.mark.asyncio
async def test_compressed_output_keeps_tool_calls_and_ids(self):
# Simulate compression that returns a summary + the most recent
# assistant(tool_call) + tool(tool_result) intact.
summary = {"role": "system", "content": "prior turns: user asked X"}
assistant_with_tc = {
"role": "assistant",
"content": "calling tool",
"tool_calls": [
{
"id": "tc_abc",
"type": "function",
"function": {"name": "search", "arguments": '{"q":"y"}'},
}
],
}
tool_result = {
"role": "tool",
"tool_call_id": "tc_abc",
"content": "search result",
}
compress_result = CompressResult(
messages=[summary, assistant_with_tc, tool_result],
token_count=100,
was_compacted=True,
original_token_count=5000,
messages_summarized=10,
messages_dropped=0,
)
# Input: messages that should be compressed.
input_messages = [
ChatMessage(role="user", content="q1"),
ChatMessage(
role="assistant",
content="calling tool",
tool_calls=[
{
"id": "tc_abc",
"type": "function",
"function": {
"name": "search",
"arguments": '{"q":"y"}',
},
}
],
),
ChatMessage(
role="tool",
tool_call_id="tc_abc",
content="search result",
),
]
with patch(
"backend.copilot.baseline.service.compress_context",
new=AsyncMock(return_value=compress_result),
):
compressed = await _compress_session_messages(
input_messages, model="openrouter/anthropic/claude-opus-4"
)
# Summary, assistant(tool_calls), tool(tool_call_id).
assert len(compressed) == 3
# Assistant message must keep its tool_calls intact.
assistant_msg = compressed[1]
assert assistant_msg.role == "assistant"
assert assistant_msg.tool_calls is not None
assert len(assistant_msg.tool_calls) == 1
assert assistant_msg.tool_calls[0]["id"] == "tc_abc"
assert assistant_msg.tool_calls[0]["function"]["name"] == "search"
# Tool-role message must keep tool_call_id for OpenAI linkage.
tool_msg = compressed[2]
assert tool_msg.role == "tool"
assert tool_msg.tool_call_id == "tc_abc"
assert tool_msg.content == "search result"
@pytest.mark.asyncio
async def test_uncompressed_passthrough_keeps_fields(self):
"""When compression is a no-op (was_compacted=False), the original
messages must be returned unchanged — including tool_calls."""
input_messages = [
ChatMessage(
role="assistant",
content="c",
tool_calls=[
{
"id": "t1",
"type": "function",
"function": {"name": "f", "arguments": "{}"},
}
],
),
ChatMessage(role="tool", tool_call_id="t1", content="ok"),
]
noop_result = CompressResult(
messages=[], # ignored when was_compacted=False
token_count=10,
was_compacted=False,
)
with patch(
"backend.copilot.baseline.service.compress_context",
new=AsyncMock(return_value=noop_result),
):
out = await _compress_session_messages(
input_messages, model="openrouter/anthropic/claude-opus-4"
)
assert out is input_messages # same list returned
assert out[0].tool_calls is not None
assert out[0].tool_calls[0]["id"] == "t1"
assert out[1].tool_call_id == "t1"
# ---- _ThinkingStripper tests ---- #
def test_thinking_stripper_basic_thinking_tag() -> None:
"""<thinking>...</thinking> blocks are fully stripped."""
s = _ThinkingStripper()
assert s.process("<thinking>internal reasoning here</thinking>Hello!") == "Hello!"
def test_thinking_stripper_internal_reasoning_tag() -> None:
"""<internal_reasoning>...</internal_reasoning> blocks (Gemini) are stripped."""
s = _ThinkingStripper()
assert (
s.process("<internal_reasoning>step by step</internal_reasoning>Answer")
== "Answer"
)
def test_thinking_stripper_split_across_chunks() -> None:
"""Tags split across multiple chunks are handled correctly."""
s = _ThinkingStripper()
out = s.process("Hello <thin")
out += s.process("king>secret</thinking> world")
assert out == "Hello world"
def test_thinking_stripper_plain_text_preserved() -> None:
"""Plain text with the word 'thinking' is not stripped."""
s = _ThinkingStripper()
assert (
s.process("I am thinking about this problem")
== "I am thinking about this problem"
)
def test_thinking_stripper_multiple_blocks() -> None:
"""Multiple reasoning blocks in one stream are all stripped."""
s = _ThinkingStripper()
result = s.process(
"A<thinking>x</thinking>B<internal_reasoning>y</internal_reasoning>C"
)
assert result == "ABC"
def test_thinking_stripper_flush_discards_unclosed() -> None:
"""Unclosed reasoning block is discarded on flush."""
s = _ThinkingStripper()
s.process("Start<thinking>never closed")
flushed = s.flush()
assert "never closed" not in flushed
def test_thinking_stripper_empty_block() -> None:
"""Empty reasoning blocks are handled gracefully."""
s = _ThinkingStripper()
assert s.process("Before<thinking></thinking>After") == "BeforeAfter"
# ---- _filter_tools_by_permissions tests ---- #
def _make_tool(name: str) -> ChatCompletionToolParam:
"""Build a minimal OpenAI ChatCompletionToolParam."""
return ChatCompletionToolParam(
type="function",
function={"name": name, "parameters": {}},
)
class TestFilterToolsByPermissions:
"""Tests for _filter_tools_by_permissions."""
@patch(
"backend.copilot.permissions.all_known_tool_names",
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
)
def test_empty_permissions_returns_all(self, _mock_names):
"""Empty permissions (no filtering) returns every tool unchanged."""
from backend.copilot.baseline.service import _filter_tools_by_permissions
from backend.copilot.permissions import CopilotPermissions
tools = [_make_tool("run_block"), _make_tool("web_fetch")]
perms = CopilotPermissions()
result = _filter_tools_by_permissions(tools, perms)
assert result == tools
@patch(
"backend.copilot.permissions.all_known_tool_names",
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
)
def test_allowlist_keeps_only_matching(self, _mock_names):
"""Explicit allowlist (tools_exclude=False) keeps only listed tools."""
from backend.copilot.baseline.service import _filter_tools_by_permissions
from backend.copilot.permissions import CopilotPermissions
tools = [
_make_tool("run_block"),
_make_tool("web_fetch"),
_make_tool("bash_exec"),
]
perms = CopilotPermissions(tools=["web_fetch"], tools_exclude=False)
result = _filter_tools_by_permissions(tools, perms)
assert len(result) == 1
assert result[0]["function"]["name"] == "web_fetch"
@patch(
"backend.copilot.permissions.all_known_tool_names",
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
)
def test_blacklist_excludes_listed(self, _mock_names):
"""Blacklist (tools_exclude=True) removes only the listed tools."""
from backend.copilot.baseline.service import _filter_tools_by_permissions
from backend.copilot.permissions import CopilotPermissions
tools = [
_make_tool("run_block"),
_make_tool("web_fetch"),
_make_tool("bash_exec"),
]
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
result = _filter_tools_by_permissions(tools, perms)
names = [t["function"]["name"] for t in result]
assert "bash_exec" not in names
assert "run_block" in names
assert "web_fetch" in names
assert len(result) == 2
@patch(
"backend.copilot.permissions.all_known_tool_names",
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
)
def test_unknown_tool_name_filtered_out(self, _mock_names):
"""A tool whose name is not in all_known_tool_names is dropped."""
from backend.copilot.baseline.service import _filter_tools_by_permissions
from backend.copilot.permissions import CopilotPermissions
tools = [_make_tool("run_block"), _make_tool("unknown_tool")]
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
result = _filter_tools_by_permissions(tools, perms)
names = [t["function"]["name"] for t in result]
assert "unknown_tool" not in names
assert names == ["run_block"]
# ---- _prepare_baseline_attachments tests ---- #
class TestPrepareBaselineAttachments:
"""Tests for _prepare_baseline_attachments."""
@pytest.mark.asyncio
async def test_empty_file_ids(self):
"""Empty file_ids returns empty hint and blocks."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
hint, blocks = await _prepare_baseline_attachments([], "user1", "sess1", "/tmp")
assert hint == ""
assert blocks == []
@pytest.mark.asyncio
async def test_empty_user_id(self):
"""Empty user_id returns empty hint and blocks."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
hint, blocks = await _prepare_baseline_attachments(
["file1"], "", "sess1", "/tmp"
)
assert hint == ""
assert blocks == []
@pytest.mark.asyncio
async def test_image_file_returns_vision_blocks(self):
"""A PNG image within size limits is returned as a base64 vision block."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
fake_info = AsyncMock()
fake_info.name = "photo.png"
fake_info.mime_type = "image/png"
fake_info.size_bytes = 1024
fake_manager = AsyncMock()
fake_manager.get_file_info = AsyncMock(return_value=fake_info)
fake_manager.read_file_by_id = AsyncMock(return_value=b"\x89PNG_FAKE_DATA")
with patch(
"backend.copilot.baseline.service.get_workspace_manager",
new=AsyncMock(return_value=fake_manager),
):
hint, blocks = await _prepare_baseline_attachments(
["fid1"], "user1", "sess1", "/tmp/workdir"
)
assert len(blocks) == 1
assert blocks[0]["type"] == "image"
assert blocks[0]["source"]["media_type"] == "image/png"
assert blocks[0]["source"]["type"] == "base64"
assert "photo.png" in hint
assert "embedded as image" in hint
@pytest.mark.asyncio
async def test_non_image_file_saved_to_working_dir(self, tmp_path):
"""A non-image file is written to working_dir."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
fake_info = AsyncMock()
fake_info.name = "data.csv"
fake_info.mime_type = "text/csv"
fake_info.size_bytes = 42
fake_manager = AsyncMock()
fake_manager.get_file_info = AsyncMock(return_value=fake_info)
fake_manager.read_file_by_id = AsyncMock(return_value=b"col1,col2\na,b")
with patch(
"backend.copilot.baseline.service.get_workspace_manager",
new=AsyncMock(return_value=fake_manager),
):
hint, blocks = await _prepare_baseline_attachments(
["fid1"], "user1", "sess1", str(tmp_path)
)
assert blocks == []
assert "data.csv" in hint
assert "saved to" in hint
saved = tmp_path / "data.csv"
assert saved.exists()
assert saved.read_bytes() == b"col1,col2\na,b"
@pytest.mark.asyncio
async def test_file_not_found_skipped(self):
"""When get_file_info returns None the file is silently skipped."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
fake_manager = AsyncMock()
fake_manager.get_file_info = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.get_workspace_manager",
new=AsyncMock(return_value=fake_manager),
):
hint, blocks = await _prepare_baseline_attachments(
["missing_id"], "user1", "sess1", "/tmp"
)
assert hint == ""
assert blocks == []
@pytest.mark.asyncio
async def test_workspace_manager_error(self):
"""When get_workspace_manager raises, returns empty results."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
with patch(
"backend.copilot.baseline.service.get_workspace_manager",
new=AsyncMock(side_effect=RuntimeError("connection failed")),
):
hint, blocks = await _prepare_baseline_attachments(
["fid1"], "user1", "sess1", "/tmp"
)
assert hint == ""
assert blocks == []

View File

@@ -1,667 +0,0 @@
"""Integration tests for baseline transcript flow.
Exercises the real helpers in ``baseline/service.py`` that download,
validate, load, append to, backfill, and upload the transcript.
Storage is mocked via ``download_transcript`` / ``upload_transcript``
patches; no network access is required.
"""
import json as stdlib_json
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.baseline.service import (
_load_prior_transcript,
_record_turn_to_transcript,
_resolve_baseline_model,
_upload_final_transcript,
is_transcript_stale,
should_upload_transcript,
)
from backend.copilot.service import config
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
TranscriptDownload,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
def _make_transcript_content(*roles: str) -> str:
"""Build a minimal valid JSONL transcript from role names."""
lines = []
parent = ""
for i, role in enumerate(roles):
uid = f"uuid-{i}"
entry: dict = {
"type": role,
"uuid": uid,
"parentUuid": parent,
"message": {
"role": role,
"content": [{"type": "text", "text": f"{role} message {i}"}],
},
}
if role == "assistant":
entry["message"]["id"] = f"msg_{i}"
entry["message"]["model"] = "test-model"
entry["message"]["type"] = "message"
entry["message"]["stop_reason"] = STOP_REASON_END_TURN
lines.append(stdlib_json.dumps(entry))
parent = uid
return "\n".join(lines) + "\n"
class TestResolveBaselineModel:
"""Model selection honours the per-request mode."""
def test_fast_mode_selects_fast_model(self):
assert _resolve_baseline_model("fast") == config.fast_model
def test_extended_thinking_selects_default_model(self):
assert _resolve_baseline_model("extended_thinking") == config.model
def test_none_mode_selects_default_model(self):
"""Critical: baseline users without a mode MUST keep the default (opus)."""
assert _resolve_baseline_model(None) == config.model
def test_default_and_fast_models_differ(self):
"""Sanity: the two tiers are actually distinct in production config."""
assert config.model != config.fast_model
class TestLoadPriorTranscript:
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
@pytest.mark.asyncio
async def test_loads_fresh_transcript(self):
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=content, message_count=2)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
@pytest.mark.asyncio
async def test_rejects_stale_transcript(self):
"""msg_count strictly less than session-1 is treated as stale."""
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
# session has 6 messages, transcript only covers 2 → stale.
download = TranscriptDownload(content=content, message_count=2)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=6,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_missing_transcript_returns_false(self):
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_invalid_transcript_returns_false(self):
builder = TranscriptBuilder()
download = TranscriptDownload(
content='{"type":"progress","uuid":"a"}\n',
message_count=1,
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_download_exception_returns_false(self):
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(side_effect=RuntimeError("boom")),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_zero_message_count_not_stale(self):
"""When msg_count is 0 (unknown), staleness check is skipped."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
message_count=0,
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=20,
transcript_builder=builder,
)
assert covers is True
assert builder.entry_count == 2
class TestUploadFinalTranscript:
"""``_upload_final_transcript`` serialises and calls storage."""
@pytest.mark.asyncio
async def test_uploads_valid_transcript(self):
builder = TranscriptBuilder()
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=2,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
call_kwargs = upload_mock.await_args.kwargs
assert call_kwargs["user_id"] == "user-1"
assert call_kwargs["session_id"] == "session-1"
assert call_kwargs["message_count"] == 2
assert "hello" in call_kwargs["content"]
@pytest.mark.asyncio
async def test_skips_upload_when_builder_empty(self):
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=0,
)
upload_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_swallows_upload_exceptions(self):
"""Upload failures should not propagate (flow continues for the user)."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=AsyncMock(side_effect=RuntimeError("storage unavailable")),
):
# Should not raise.
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=2,
)
class TestRecordTurnToTranscript:
"""``_record_turn_to_transcript`` translates LLMLoopResponse → transcript."""
def test_records_final_assistant_text(self):
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text="hello there",
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
jsonl = builder.to_jsonl()
assert "hello there" in jsonl
assert STOP_REASON_END_TURN in jsonl
def test_records_tool_use_then_tool_result(self):
"""Anthropic ordering: assistant(tool_use) → user(tool_result)."""
builder = TranscriptBuilder()
builder.append_user(content="use a tool")
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="call-1", name="echo", arguments='{"text":"hi"}')
],
raw_response=None,
)
tool_results = [
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="hi")
]
_record_turn_to_transcript(
response,
tool_results,
transcript_builder=builder,
model="test-model",
)
# user, assistant(tool_use), user(tool_result) = 3 entries
assert builder.entry_count == 3
jsonl = builder.to_jsonl()
assert STOP_REASON_TOOL_USE in jsonl
assert "tool_use" in jsonl
assert "tool_result" in jsonl
assert "call-1" in jsonl
def test_records_nothing_on_empty_response(self):
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text=None,
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 1
def test_malformed_tool_args_dont_crash(self):
"""Bad JSON in tool arguments falls back to {} without raising."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text=None,
tool_calls=[LLMToolCall(id="call-1", name="echo", arguments="{not-json")],
raw_response=None,
)
tool_results = [
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="ok")
]
_record_turn_to_transcript(
response,
tool_results,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 3
jsonl = builder.to_jsonl()
assert '"input":{}' in jsonl
class TestRoundTrip:
"""End-to-end: load prior → append new turn → upload."""
@pytest.mark.asyncio
async def test_full_round_trip(self):
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
assert builder.entry_count == 2
# New user turn.
builder.append_user(content="new question")
assert builder.entry_count == 3
# New assistant turn.
response = LLMLoopResponse(
response_text="new answer",
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 4
# Upload.
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=4,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "new question" in uploaded
assert "new answer" in uploaded
# Original content preserved in the round trip.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_backfill_append_guard(self):
"""Backfill only runs when the last entry is not already assistant."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
# Simulate the backfill guard from stream_chat_completion_baseline.
assistant_text = "partial text before error"
if builder.last_entry_type != "assistant":
builder.append_assistant(
content_blocks=[{"type": "text", "text": assistant_text}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
assert builder.last_entry_type == "assistant"
assert "partial text before error" in builder.to_jsonl()
# Second invocation: the guard must prevent double-append.
initial_count = builder.entry_count
if builder.last_entry_type != "assistant":
builder.append_assistant(
content_blocks=[{"type": "text", "text": "duplicate"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
assert builder.entry_count == initial_count
class TestIsTranscriptStale:
"""``is_transcript_stale`` gates prior-transcript loading."""
def test_none_download_is_not_stale(self):
assert is_transcript_stale(None, session_msg_count=5) is False
def test_zero_message_count_is_not_stale(self):
"""Legacy transcripts without msg_count tracking must remain usable."""
dl = TranscriptDownload(content="", message_count=0)
assert is_transcript_stale(dl, session_msg_count=20) is False
def test_stale_when_covers_less_than_prefix(self):
dl = TranscriptDownload(content="", message_count=2)
# session has 6 messages; transcript must cover at least 5 (6-1).
assert is_transcript_stale(dl, session_msg_count=6) is True
def test_fresh_when_covers_full_prefix(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_fresh_when_exceeds_prefix(self):
"""Race: transcript ahead of session count is still acceptable."""
dl = TranscriptDownload(content="", message_count=10)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_boundary_equal_to_prefix_minus_one(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
class TestShouldUploadTranscript:
"""``should_upload_transcript`` gates the final upload."""
def test_upload_allowed_for_user_with_coverage(self):
assert should_upload_transcript("user-1", True) is True
def test_upload_skipped_when_no_user(self):
assert should_upload_transcript(None, True) is False
def test_upload_skipped_when_empty_user(self):
assert should_upload_transcript("", True) is False
def test_upload_skipped_without_coverage(self):
"""Partial transcript must never clobber a more complete stored one."""
assert should_upload_transcript("user-1", False) is False
def test_upload_skipped_when_no_user_and_no_coverage(self):
assert should_upload_transcript(None, False) is False
class TestTranscriptLifecycle:
"""End-to-end: download → validate → build → upload.
Simulates the full transcript lifecycle inside
``stream_chat_completion_baseline`` by mocking the storage layer and
driving each step through the real helpers.
"""
@pytest.mark.asyncio
async def test_full_lifecycle_happy_path(self):
"""Fresh download, append a turn, upload covers the session."""
builder = TranscriptBuilder()
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
# --- 1. Download & load prior transcript ---
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
# --- 2. Append a new user turn + a new assistant response ---
builder.append_user(content="follow-up question")
_record_turn_to_transcript(
LLMLoopResponse(
response_text="follow-up answer",
tool_calls=[],
raw_response=None,
),
tool_results=None,
transcript_builder=builder,
model="test-model",
)
# --- 3. Gate + upload ---
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is True
)
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=4,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "follow-up question" in uploaded
assert "follow-up answer" in uploaded
# Original prior-turn content preserved.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_lifecycle_stale_download_suppresses_upload(self):
"""Stale download → covers=False → upload must be skipped."""
builder = TranscriptBuilder()
# session has 10 msgs but stored transcript only covers 2 → stale.
stale = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
message_count=2,
)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=stale),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
transcript_builder=builder,
)
assert covers is False
# The caller's gate mirrors the production path.
assert (
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
is False
)
upload_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_lifecycle_anonymous_user_skips_upload(self):
"""Anonymous (user_id=None) → upload gate must return False."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
assert (
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
is False
)
@pytest.mark.asyncio
async def test_lifecycle_missing_download_still_uploads_new_content(self):
"""No prior transcript → covers defaults to True in the service,
new turn should upload cleanly."""
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=1,
transcript_builder=builder,
)
# No download: covers is False, so the production path would
# skip upload. This protects against overwriting a future
# more-complete transcript with a single-turn snapshot.
assert covers is False
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is False
)
upload_mock.assert_not_awaited()

View File

@@ -8,35 +8,18 @@ from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
# Per-request routing mode for a single chat turn.
# - 'fast': route to the baseline OpenAI-compatible path with the cheaper model.
# - 'extended_thinking': route to the Claude Agent SDK path with the default
# (opus) model.
# ``None`` means "no override"; the server falls back to the Claude Code
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
CopilotMode = Literal["fast", "extended_thinking"]
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
# OpenAI API Configuration
model: str = Field(
default="anthropic/claude-opus-4.6",
description="Default model for extended thinking mode",
)
fast_model: str = Field(
default="anthropic/claude-sonnet-4",
description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
default="anthropic/claude-opus-4.6", description="Default model to use"
)
title_model: str = Field(
default="openai/gpt-4o-mini",
description="Model to use for generating session titles (should be fast/cheap)",
)
simulation_model: str = Field(
default="google/gemini-2.5-flash",
description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default=OPENROUTER_BASE_URL,
@@ -94,11 +77,11 @@ class ChatConfig(BaseSettings):
# allows ~70-100 turns/day.
# Checked at the HTTP layer (routes.py) before each turn.
#
# These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
# ENTERPRISE) multiply these by their tier multiplier (see
# rate_limit.TIER_MULTIPLIERS). User tier is stored in the
# User.subscriptionTier DB column and resolved inside
# get_global_rate_limits().
# TODO: These are deploy-time constants applied identically to every user.
# If per-user or per-plan limits are needed (e.g., free tier vs paid), these
# must move to the database (e.g., a UserPlan table) and get_usage_status /
# check_rate_limit would look up each user's specific limits instead of
# reading config.daily_token_limit / config.weekly_token_limit.
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
@@ -195,7 +178,7 @@ class ChatConfig(BaseSettings):
Single source of truth for "will the SDK route through OpenRouter?".
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
present — mirrors the fallback logic in ``build_sdk_env``.
present — mirrors the fallback logic in ``_build_sdk_env``.
"""
if not self.use_openrouter:
return False

View File

@@ -149,8 +149,7 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``
or ``tool-outputs/...``.
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
The SDK nests tool-results under a conversation UUID directory;
the UUID segment is validated with ``_UUID_RE``.
"""
@@ -175,20 +174,17 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
# Defence-in-depth: ensure project_dir didn't escape the base.
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
return False
# Only allow: <encoded-cwd>/<uuid>/<tool-dir>/<file>
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
# The SDK always creates a conversation UUID directory between
# the project dir and the tool directory.
# Accept both "tool-results" (SDK's persisted outputs) and
# "tool-outputs" (the model sometimes confuses workspace paths
# with filesystem paths and generates this variant).
# the project dir and tool-results/.
if resolved.startswith(project_dir + os.sep):
relative = resolved[len(project_dir) + 1 :]
parts = relative.split(os.sep)
# Require exactly: [<uuid>, "tool-results"|"tool-outputs", <file>, ...]
# Require exactly: [<uuid>, "tool-results", <file>, ...]
if (
len(parts) >= 3
and _UUID_RE.match(parts[0])
and parts[1] in ("tool-results", "tool-outputs")
and parts[1] == "tool-results"
):
return True

View File

@@ -134,21 +134,6 @@ def test_is_allowed_local_path_tool_results_with_uuid():
_current_project_dir.set("")
def test_is_allowed_local_path_tool_outputs_with_uuid():
"""Files under <encoded-cwd>/<uuid>/tool-outputs/ are also allowed."""
encoded = "test-encoded-dir"
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
path = os.path.join(
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-outputs", "output.json"
)
_current_project_dir.set(encoded)
try:
assert is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
encoded = "test-encoded-dir"
@@ -174,7 +159,7 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
"""A valid UUID dir but non-'tool-results'/'tool-outputs' second segment is rejected."""
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
encoded = "test-encoded-dir"
uuid_str = "12345678-1234-5678-9abc-def012345678"
path = os.path.join(

View File

@@ -14,32 +14,15 @@ from prisma.types import (
ChatSessionUpdateInput,
ChatSessionWhereInput,
)
from pydantic import BaseModel
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
from .model import (
ChatMessage,
ChatSession,
ChatSessionInfo,
ChatSessionMetadata,
cache_chat_session,
)
from .model import get_chat_session as get_chat_session_cached
from .model import ChatMessage, ChatSession, ChatSessionInfo
logger = logging.getLogger(__name__)
class PaginatedMessages(BaseModel):
"""Result of a paginated message query."""
messages: list[ChatMessage]
has_more: bool
oldest_sequence: int | None
session: ChatSessionInfo
async def get_chat_session(session_id: str) -> ChatSession | None:
"""Get a chat session by ID from the database."""
session = await PrismaChatSession.prisma().find_unique(
@@ -49,120 +32,9 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
return ChatSession.from_db(session) if session else None
async def get_chat_session_metadata(session_id: str) -> ChatSessionInfo | None:
"""Get chat session metadata (without messages) for ownership validation."""
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
)
return ChatSessionInfo.from_db(session) if session else None
async def get_chat_messages_paginated(
session_id: str,
limit: int = 50,
before_sequence: int | None = None,
user_id: str | None = None,
) -> PaginatedMessages | None:
"""Get paginated messages for a session, newest first.
Verifies session existence (and ownership when ``user_id`` is provided)
in parallel with the message query. Returns ``None`` when the session
is not found or does not belong to the user.
Args:
session_id: The chat session ID.
limit: Max messages to return.
before_sequence: Cursor — return messages with sequence < this value.
user_id: If provided, filters via ``Session.userId`` so only the
session owner's messages are returned (acts as an ownership guard).
"""
# Build session-existence / ownership check
session_where: ChatSessionWhereInput = {"id": session_id}
if user_id is not None:
session_where["userId"] = user_id
# Build message include — fetch paginated messages in the same query
msg_include: dict[str, Any] = {
"order_by": {"sequence": "desc"},
"take": limit + 1,
}
if before_sequence is not None:
msg_include["where"] = {"sequence": {"lt": before_sequence}}
# Single query: session existence/ownership + paginated messages
session = await PrismaChatSession.prisma().find_first(
where=session_where,
include={"Messages": msg_include},
)
if session is None:
return None
session_info = ChatSessionInfo.from_db(session)
results = list(session.Messages) if session.Messages else []
has_more = len(results) > limit
results = results[:limit]
# Reverse to ascending order
results.reverse()
# Tool-call boundary fix: if the oldest message is a tool message,
# expand backward to include the preceding assistant message that
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
# can pair them correctly.
_BOUNDARY_SCAN_LIMIT = 10
if results and results[0].role == "tool":
boundary_where: dict[str, Any] = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
boundary_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=boundary_where,
order={"sequence": "desc"},
take=_BOUNDARY_SCAN_LIMIT,
)
# Find the first non-tool message (should be the assistant)
boundary_msgs = []
found_owner = False
for msg in extra:
boundary_msgs.append(msg)
if msg.role != "tool":
found_owner = True
break
boundary_msgs.reverse()
if not found_owner:
logger.warning(
"Boundary expansion did not find owning assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
if boundary_msgs:
results = boundary_msgs + results
# Only mark has_more if the expanded boundary isn't the
# very start of the conversation (sequence 0).
if boundary_msgs[0].sequence > 0:
has_more = True
messages = [ChatMessage.from_db(m) for m in results]
oldest_sequence = messages[0].sequence if messages else None
return PaginatedMessages(
messages=messages,
has_more=has_more,
oldest_sequence=oldest_sequence,
session=session_info,
)
async def create_chat_session(
session_id: str,
user_id: str,
metadata: ChatSessionMetadata | None = None,
) -> ChatSessionInfo:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
@@ -171,7 +43,6 @@ async def create_chat_session(
credentials=SafeJson({}),
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
metadata=SafeJson((metadata or ChatSessionMetadata()).model_dump()),
)
prisma_session = await PrismaChatSession.prisma().create(data=data)
return ChatSessionInfo.from_db(prisma_session)
@@ -186,12 +57,7 @@ async def update_chat_session(
total_completion_tokens: int | None = None,
title: str | None = None,
) -> ChatSession | None:
"""Update a chat session's mutable fields.
Note: ``metadata`` (which includes ``dry_run``) is intentionally omitted —
it is set once at creation time and treated as immutable for the lifetime
of the session.
"""
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
if credentials is not None:
@@ -351,9 +217,6 @@ async def add_chat_messages_batch(
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
if msg.get("duration_ms") is not None:
data["durationMs"] = msg["duration_ms"]
messages_data.append(data)
# Run create_many and session update in parallel within transaction
@@ -496,33 +359,3 @@ async def update_tool_message_content(
f"tool_call_id {tool_call_id}: {e}"
)
return False
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
"""Set durationMs on the last assistant message in a session.
Updates the Redis cache in-place instead of invalidating it.
Invalidation would delete the key, creating a window where concurrent
``get_chat_session`` calls re-populate the cache from DB — potentially
with stale data if the DB write from the previous turn hasn't propagated.
This race caused duplicate user messages on the next turn.
"""
last_msg = await PrismaChatMessage.prisma().find_first(
where={"sessionId": session_id, "role": "assistant"},
order={"sequence": "desc"},
)
if last_msg:
await PrismaChatMessage.prisma().update(
where={"id": last_msg.id},
data={"durationMs": duration_ms},
)
# Update cache in-place rather than invalidating to avoid a
# race window where the empty cache gets re-populated with
# stale data by a concurrent get_chat_session call.
session = await get_chat_session_cached(session_id)
if session and session.messages:
for msg in reversed(session.messages):
if msg.role == "assistant":
msg.duration_ms = duration_ms
break
await cache_chat_session(session)

View File

@@ -1,388 +0,0 @@
"""Unit tests for copilot.db — paginated message queries."""
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from backend.copilot.db import (
PaginatedMessages,
get_chat_messages_paginated,
set_turn_duration,
)
from backend.copilot.model import ChatMessage as CopilotChatMessage
from backend.copilot.model import ChatSession, get_chat_session, upsert_chat_session
def _make_msg(
sequence: int,
role: str = "assistant",
content: str | None = "hello",
tool_calls: Any = None,
) -> PrismaChatMessage:
"""Build a minimal PrismaChatMessage for testing."""
return PrismaChatMessage(
id=f"msg-{sequence}",
createdAt=datetime.now(UTC),
sessionId="sess-1",
role=role,
content=content,
sequence=sequence,
toolCalls=tool_calls,
name=None,
toolCallId=None,
refusal=None,
functionCall=None,
)
def _make_session(
session_id: str = "sess-1",
user_id: str = "user-1",
messages: list[PrismaChatMessage] | None = None,
) -> PrismaChatSession:
"""Build a minimal PrismaChatSession for testing."""
now = datetime.now(UTC)
session = PrismaChatSession.model_construct(
id=session_id,
createdAt=now,
updatedAt=now,
userId=user_id,
credentials={},
successfulAgentRuns={},
successfulAgentSchedules={},
totalPromptTokens=0,
totalCompletionTokens=0,
title=None,
metadata={},
Messages=messages or [],
)
return session
SESSION_ID = "sess-1"
@pytest.fixture()
def mock_db():
"""Patch ChatSession.prisma().find_first and ChatMessage.prisma().find_many.
find_first is used for the main query (session + included messages).
find_many is used only for boundary expansion queries.
"""
with (
patch.object(PrismaChatSession, "prisma") as mock_session_prisma,
patch.object(PrismaChatMessage, "prisma") as mock_msg_prisma,
):
find_first = AsyncMock()
mock_session_prisma.return_value.find_first = find_first
find_many = AsyncMock(return_value=[])
mock_msg_prisma.return_value.find_many = find_many
yield find_first, find_many
# ---------- Basic pagination ----------
@pytest.mark.asyncio
async def test_basic_page_returns_messages_ascending(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Messages are returned in ascending sequence order."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3), _make_msg(2), _make_msg(1)],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert isinstance(page, PaginatedMessages)
assert [m.sequence for m in page.messages] == [1, 2, 3]
assert page.has_more is False
assert page.oldest_sequence == 1
@pytest.mark.asyncio
async def test_has_more_when_results_exceed_limit(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""has_more is True when DB returns more than limit items."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3), _make_msg(2), _make_msg(1)],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
assert page is not None
assert page.has_more is True
assert len(page.messages) == 2
assert [m.sequence for m in page.messages] == [2, 3]
@pytest.mark.asyncio
async def test_empty_session_returns_no_messages(
mock_db: tuple[AsyncMock, AsyncMock],
):
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[])
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
assert page is not None
assert page.messages == []
assert page.has_more is False
assert page.oldest_sequence is None
@pytest.mark.asyncio
async def test_before_sequence_filters_correctly(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""before_sequence is passed as a where filter inside the Messages include."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(2), _make_msg(1)],
)
await get_chat_messages_paginated(SESSION_ID, limit=50, before_sequence=5)
call_kwargs = find_first.call_args
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
assert include["Messages"]["where"] == {"sequence": {"lt": 5}}
@pytest.mark.asyncio
async def test_no_where_on_messages_without_before_sequence(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Without before_sequence, the Messages include has no where clause."""
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[_make_msg(1)])
await get_chat_messages_paginated(SESSION_ID, limit=50)
call_kwargs = find_first.call_args
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
assert "where" not in include["Messages"]
@pytest.mark.asyncio
async def test_user_id_filter_applied_to_session_where(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""user_id adds a userId filter to the session-level where clause."""
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[_make_msg(1)])
await get_chat_messages_paginated(SESSION_ID, limit=50, user_id="user-abc")
call_kwargs = find_first.call_args
where = call_kwargs.kwargs.get("where") or call_kwargs[1].get("where")
assert where["userId"] == "user-abc"
@pytest.mark.asyncio
async def test_session_not_found_returns_none(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Returns None when session doesn't exist or user doesn't own it."""
find_first, _ = mock_db
find_first.return_value = None
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
assert page is None
@pytest.mark.asyncio
async def test_session_info_included_in_result(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""PaginatedMessages includes session metadata."""
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[_make_msg(1)])
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
assert page is not None
assert page.session.session_id == SESSION_ID
# ---------- Backward boundary expansion ----------
@pytest.mark.asyncio
async def test_boundary_expansion_includes_assistant(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When page starts with a tool message, expand backward to include
the owning assistant message."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(5, role="tool"), _make_msg(4, role="tool")],
)
find_many.return_value = [_make_msg(3, role="assistant")]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert [m.sequence for m in page.messages] == [3, 4, 5]
assert page.messages[0].role == "assistant"
assert page.oldest_sequence == 3
@pytest.mark.asyncio
async def test_boundary_expansion_includes_multiple_tool_msgs(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Boundary expansion scans past consecutive tool messages to find
the owning assistant."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(7, role="tool")],
)
find_many.return_value = [
_make_msg(6, role="tool"),
_make_msg(5, role="tool"),
_make_msg(4, role="assistant"),
]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert [m.sequence for m in page.messages] == [4, 5, 6, 7]
assert page.messages[0].role == "assistant"
@pytest.mark.asyncio
async def test_boundary_expansion_sets_has_more_when_not_at_start(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""After boundary expansion, has_more=True if expanded msgs aren't at seq 0."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3, role="tool")],
)
find_many.return_value = [_make_msg(2, role="assistant")]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert page.has_more is True
@pytest.mark.asyncio
async def test_boundary_expansion_no_has_more_at_conversation_start(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""has_more stays False when boundary expansion reaches seq 0."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(1, role="tool")],
)
find_many.return_value = [_make_msg(0, role="assistant")]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert page.has_more is False
assert page.oldest_sequence == 0
@pytest.mark.asyncio
async def test_no_boundary_expansion_when_first_msg_not_tool(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""No boundary expansion when the first message is not a tool message."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3, role="user"), _make_msg(2, role="assistant")],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert find_many.call_count == 0
assert [m.sequence for m in page.messages] == [2, 3]
@pytest.mark.asyncio
async def test_boundary_expansion_warns_when_no_owner_found(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When boundary scan doesn't find a non-tool message, a warning is logged
and the boundary messages are still included."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(10, role="tool")],
)
find_many.return_value = [_make_msg(i, role="tool") for i in range(9, -1, -1)]
with patch("backend.copilot.db.logger") as mock_logger:
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
mock_logger.warning.assert_called_once()
assert page is not None
assert page.messages[0].role == "tool"
assert len(page.messages) > 1
# ---------- Turn duration (integration tests) ----------
@pytest.mark.asyncio(loop_scope="session")
async def test_set_turn_duration_updates_cache_in_place(setup_test_user, test_user_id):
"""set_turn_duration patches the cached session without invalidation.
Verifies that after calling set_turn_duration the Redis-cached session
reflects the updated durationMs on the last assistant message, without
the cache having been deleted and re-populated (which could race with
concurrent get_chat_session calls).
"""
session = ChatSession.new(user_id=test_user_id, dry_run=False)
session.messages = [
CopilotChatMessage(role="user", content="hello"),
CopilotChatMessage(role="assistant", content="hi there"),
]
session = await upsert_chat_session(session)
# Ensure the session is in cache
cached = await get_chat_session(session.session_id, test_user_id)
assert cached is not None
assert cached.messages[-1].duration_ms is None
# Update turn duration — should patch cache in-place
await set_turn_duration(session.session_id, 1234)
# Read from cache (not DB) — the cache should already have the update
updated = await get_chat_session(session.session_id, test_user_id)
assert updated is not None
assistant_msgs = [m for m in updated.messages if m.role == "assistant"]
assert len(assistant_msgs) == 1
assert assistant_msgs[0].duration_ms == 1234
@pytest.mark.asyncio(loop_scope="session")
async def test_set_turn_duration_no_assistant_message(setup_test_user, test_user_id):
"""set_turn_duration is a no-op when there are no assistant messages."""
session = ChatSession.new(user_id=test_user_id, dry_run=False)
session.messages = [
CopilotChatMessage(role="user", content="hello"),
]
session = await upsert_chat_session(session)
# Should not raise
await set_turn_duration(session.session_id, 5678)
cached = await get_chat_session(session.session_id, test_user_id)
assert cached is not None
# User message should not have durationMs
assert cached.messages[0].duration_ms is None

View File

@@ -13,7 +13,7 @@ import time
from backend.copilot import stream_registry
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamError
from backend.copilot.sdk import service as sdk_service
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
@@ -30,57 +30,6 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
# ============ Mode Routing ============ #
async def resolve_effective_mode(
mode: CopilotMode | None,
user_id: str | None,
) -> CopilotMode | None:
"""Strip ``mode`` when the user is not entitled to the toggle.
The UI gates the mode toggle behind ``CHAT_MODE_OPTION``; the
processor enforces the same gate server-side so an authenticated
user cannot bypass the flag by crafting a request directly.
"""
if mode is None:
return None
allowed = await is_feature_enabled(
Flag.CHAT_MODE_OPTION,
user_id or "anonymous",
default=False,
)
if not allowed:
logger.info(f"Ignoring mode={mode} — CHAT_MODE_OPTION is disabled for user")
return None
return mode
async def resolve_use_sdk_for_mode(
mode: CopilotMode | None,
user_id: str | None,
*,
use_claude_code_subscription: bool,
config_default: bool,
) -> bool:
"""Pick the SDK vs baseline path for a single turn.
Per-request ``mode`` wins whenever it is set (after the
``CHAT_MODE_OPTION`` gate has been applied upstream). Otherwise
falls back to the Claude Code subscription override, then the
``COPILOT_SDK`` LaunchDarkly flag, then the config default.
"""
if mode == "fast":
return False
if mode == "extended_thinking":
return True
return use_claude_code_subscription or await is_feature_enabled(
Flag.COPILOT_SDK,
user_id or "anonymous",
default=config_default,
)
# ============ Module Entry Points ============ #
# Thread-local storage for processor instances
@@ -301,26 +250,21 @@ class CoPilotProcessor:
if config.test_mode:
stream_fn = stream_chat_completion_dummy
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
effective_mode = None
else:
# Enforce server-side feature-flag gate so unauthorised
# users cannot force a mode by crafting the request.
effective_mode = await resolve_effective_mode(entry.mode, entry.user_id)
use_sdk = await resolve_use_sdk_for_mode(
effective_mode,
entry.user_id,
use_claude_code_subscription=config.use_claude_code_subscription,
config_default=config.use_claude_agent_sdk,
use_sdk = (
config.use_claude_code_subscription
or await is_feature_enabled(
Flag.COPILOT_SDK,
entry.user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else stream_chat_completion_baseline
)
log.info(
f"Using {'SDK' if use_sdk else 'baseline'} service "
f"(mode={effective_mode or 'default'})"
)
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
# Stream chat completion and publish chunks to Redis.
# stream_and_publish wraps the raw stream with registry
@@ -332,7 +276,6 @@ class CoPilotProcessor:
user_id=entry.user_id,
context=entry.context,
file_ids=entry.file_ids,
mode=effective_mode,
)
async for chunk in stream_registry.stream_and_publish(
session_id=entry.session_id,

View File

@@ -1,175 +0,0 @@
"""Unit tests for CoPilot mode routing logic in the processor.
Tests cover the mode→service mapping:
- 'fast' → baseline service
- 'extended_thinking' → SDK service
- None → feature flag / config fallback
as well as the ``CHAT_MODE_OPTION`` server-side gate. The tests import
the real production helpers from ``processor.py`` so the routing logic
has meaningful coverage.
"""
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.executor.processor import (
resolve_effective_mode,
resolve_use_sdk_for_mode,
)
class TestResolveUseSdkForMode:
"""Tests for the per-request mode routing logic."""
@pytest.mark.asyncio
async def test_fast_mode_uses_baseline(self):
"""mode='fast' always routes to baseline, regardless of flags."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
):
assert (
await resolve_use_sdk_for_mode(
"fast",
"user-1",
use_claude_code_subscription=True,
config_default=True,
)
is False
)
@pytest.mark.asyncio
async def test_extended_thinking_uses_sdk(self):
"""mode='extended_thinking' always routes to SDK, regardless of flags."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert (
await resolve_use_sdk_for_mode(
"extended_thinking",
"user-1",
use_claude_code_subscription=False,
config_default=False,
)
is True
)
@pytest.mark.asyncio
async def test_none_mode_uses_subscription_override(self):
"""mode=None with claude_code_subscription=True routes to SDK."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=True,
config_default=False,
)
is True
)
@pytest.mark.asyncio
async def test_none_mode_uses_feature_flag(self):
"""mode=None with feature flag enabled routes to SDK."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
) as flag_mock:
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
config_default=False,
)
is True
)
flag_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_none_mode_uses_config_default(self):
"""mode=None falls back to config.use_claude_agent_sdk."""
# When LaunchDarkly returns the default (True), we expect SDK routing.
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
):
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
config_default=True,
)
is True
)
@pytest.mark.asyncio
async def test_none_mode_all_disabled(self):
"""mode=None with all flags off routes to baseline."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
config_default=False,
)
is False
)
class TestResolveEffectiveMode:
"""Tests for the CHAT_MODE_OPTION server-side gate."""
@pytest.mark.asyncio
async def test_none_mode_passes_through(self):
"""mode=None is returned as-is without a flag check."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
) as flag_mock:
assert await resolve_effective_mode(None, "user-1") is None
flag_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_mode_stripped_when_flag_disabled(self):
"""When CHAT_MODE_OPTION is off, mode is dropped to None."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert await resolve_effective_mode("fast", "user-1") is None
assert await resolve_effective_mode("extended_thinking", "user-1") is None
@pytest.mark.asyncio
async def test_mode_preserved_when_flag_enabled(self):
"""When CHAT_MODE_OPTION is on, the user-selected mode is preserved."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
):
assert await resolve_effective_mode("fast", "user-1") == "fast"
assert (
await resolve_effective_mode("extended_thinking", "user-1")
== "extended_thinking"
)
@pytest.mark.asyncio
async def test_anonymous_user_with_mode(self):
"""Anonymous users (user_id=None) still pass through the gate."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
) as flag_mock:
assert await resolve_effective_mode("fast", None) is None
flag_mock.assert_awaited_once()

View File

@@ -9,7 +9,6 @@ import logging
from pydantic import BaseModel
from backend.copilot.config import CopilotMode
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
@@ -157,9 +156,6 @@ class CoPilotExecutionEntry(BaseModel):
file_ids: list[str] | None = None
"""Workspace file IDs attached to the user's message"""
mode: CopilotMode | None = None
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
@@ -179,7 +175,6 @@ async def enqueue_copilot_turn(
is_user_message: bool = True,
context: dict[str, str] | None = None,
file_ids: list[str] | None = None,
mode: CopilotMode | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
@@ -191,7 +186,6 @@ async def enqueue_copilot_turn(
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
file_ids: Optional workspace file IDs attached to the user's message
mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
"""
from backend.util.clients import get_async_copilot_queue
@@ -203,7 +197,6 @@ async def enqueue_copilot_turn(
is_user_message=is_user_message,
context=context,
file_ids=file_ids,
mode=mode,
)
queue_client = await get_async_copilot_queue()

View File

@@ -1,123 +0,0 @@
"""Tests for CoPilot executor utils (queue config, message models, logging)."""
from backend.copilot.executor.utils import (
COPILOT_EXECUTION_EXCHANGE,
COPILOT_EXECUTION_QUEUE_NAME,
COPILOT_EXECUTION_ROUTING_KEY,
CancelCoPilotEvent,
CoPilotExecutionEntry,
CoPilotLogMetadata,
create_copilot_queue_config,
)
class TestCoPilotExecutionEntry:
def test_basic_fields(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="hello",
)
assert entry.session_id == "s1"
assert entry.user_id == "u1"
assert entry.message == "hello"
assert entry.is_user_message is True
assert entry.mode is None
assert entry.context is None
assert entry.file_ids is None
def test_mode_field(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="test",
mode="fast",
)
assert entry.mode == "fast"
entry2 = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="test",
mode="extended_thinking",
)
assert entry2.mode == "extended_thinking"
def test_optional_fields(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="test",
turn_id="t1",
context={"url": "https://example.com"},
file_ids=["f1", "f2"],
is_user_message=False,
)
assert entry.turn_id == "t1"
assert entry.context == {"url": "https://example.com"}
assert entry.file_ids == ["f1", "f2"]
assert entry.is_user_message is False
def test_serialization_roundtrip(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="hello",
mode="fast",
)
json_str = entry.model_dump_json()
restored = CoPilotExecutionEntry.model_validate_json(json_str)
assert restored == entry
class TestCancelCoPilotEvent:
def test_basic(self):
event = CancelCoPilotEvent(session_id="s1")
assert event.session_id == "s1"
def test_serialization(self):
event = CancelCoPilotEvent(session_id="s1")
restored = CancelCoPilotEvent.model_validate_json(event.model_dump_json())
assert restored.session_id == "s1"
class TestCreateCopilotQueueConfig:
def test_returns_valid_config(self):
config = create_copilot_queue_config()
assert len(config.exchanges) == 2
assert len(config.queues) == 2
def test_execution_queue_properties(self):
config = create_copilot_queue_config()
exec_queue = next(
q for q in config.queues if q.name == COPILOT_EXECUTION_QUEUE_NAME
)
assert exec_queue.durable is True
assert exec_queue.exchange == COPILOT_EXECUTION_EXCHANGE
assert exec_queue.routing_key == COPILOT_EXECUTION_ROUTING_KEY
def test_cancel_queue_uses_fanout(self):
config = create_copilot_queue_config()
cancel_queue = next(
q for q in config.queues if q.name != COPILOT_EXECUTION_QUEUE_NAME
)
assert cancel_queue.exchange is not None
assert cancel_queue.exchange.type.value == "fanout"
class TestCoPilotLogMetadata:
def test_creates_logger_with_metadata(self):
import logging
base_logger = logging.getLogger("test")
log = CoPilotLogMetadata(base_logger, session_id="s1", user_id="u1")
assert log is not None
def test_filters_none_values(self):
import logging
base_logger = logging.getLogger("test")
log = CoPilotLogMetadata(
base_logger, session_id="s1", user_id=None, turn_id="t1"
)
assert log is not None

View File

@@ -59,16 +59,6 @@ _null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)
# GitHub user identity caches (keyed by user_id only, not provider tuple).
# Declared here so invalidate_user_provider_cache() can reference them.
_GH_IDENTITY_CACHE_TTL = 600.0 # 10 min — profile data rarely changes
_gh_identity_cache: TTLCache[str, dict[str, str]] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_GH_IDENTITY_CACHE_TTL
)
_gh_identity_null_cache: TTLCache[str, bool] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
"""Remove the cached entry for *user_id*/*provider* from both caches.
@@ -76,19 +66,11 @@ def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
Call this after storing new credentials so that the next
``get_provider_token()`` call performs a fresh DB lookup instead of
serving a stale TTL-cached result.
For GitHub specifically, also clears the git-identity caches so that
``get_github_user_git_identity()`` re-fetches the user's profile on
the next call instead of serving stale identity data.
"""
key = (user_id, provider)
_token_cache.pop(key, None)
_null_cache.pop(key, None)
if provider == "github":
_gh_identity_cache.pop(user_id, None)
_gh_identity_null_cache.pop(user_id, None)
# Register this module's cache-bust function with the credentials manager so
# that any create/update/delete operation immediately evicts stale cache
@@ -141,7 +123,6 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
[c for c in creds_list if c.type == "oauth2"],
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
)
refresh_failed = False
for creds in oauth2_creds:
if creds.type == "oauth2":
try:
@@ -160,7 +141,6 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
# Do NOT fall back to the stale token — it is likely expired
# or revoked. Returning None forces the caller to re-auth,
# preventing the LLM from receiving a non-functional token.
refresh_failed = True
continue
_token_cache[cache_key] = token
return token
@@ -172,12 +152,8 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
_token_cache[cache_key] = token
return token
# Only cache "not connected" when the user truly has no credentials for this
# provider. If we had OAuth credentials but refresh failed (e.g. transient
# network error, event-loop mismatch), do NOT cache the negative result —
# the next call should retry the refresh instead of being blocked for 60 s.
if not refresh_failed:
_null_cache[cache_key] = True
# No credentials found — cache to avoid repeated DB hits.
_null_cache[cache_key] = True
return None
@@ -195,76 +171,3 @@ async def get_integration_env_vars(user_id: str) -> dict[str, str]:
for var in var_names:
env[var] = token
return env
# ---------------------------------------------------------------------------
# GitHub user identity (for git committer env vars)
# ---------------------------------------------------------------------------
async def get_github_user_git_identity(user_id: str) -> dict[str, str] | None:
"""Fetch the GitHub user's name and email for git committer env vars.
Uses the ``/user`` GitHub API endpoint with the user's stored token.
Returns a dict with ``GIT_AUTHOR_NAME``, ``GIT_AUTHOR_EMAIL``,
``GIT_COMMITTER_NAME``, and ``GIT_COMMITTER_EMAIL`` if the user has a
connected GitHub account. Returns ``None`` otherwise.
Results are cached for 10 minutes; "not connected" results are cached for
60 s (same as null-token cache).
"""
if user_id in _gh_identity_null_cache:
return None
if cached := _gh_identity_cache.get(user_id):
return cached
token = await get_provider_token(user_id, "github")
if not token:
_gh_identity_null_cache[user_id] = True
return None
import aiohttp
try:
async with aiohttp.ClientSession() as session:
async with session.get(
"https://api.github.com/user",
headers={
"Authorization": f"token {token}",
"Accept": "application/vnd.github+json",
},
timeout=aiohttp.ClientTimeout(total=5),
) as resp:
if resp.status != 200:
logger.warning(
"[git-identity] GitHub /user returned %s for user %s",
resp.status,
user_id,
)
return None
data = await resp.json()
except Exception as exc:
logger.warning(
"[git-identity] Failed to fetch GitHub profile for user %s: %s",
user_id,
exc,
)
return None
name = data.get("name") or data.get("login") or "AutoGPT User"
# GitHub may return email=null if the user has set their email to private.
# Fall back to the noreply address GitHub generates for every account.
email = data.get("email")
if not email:
gh_id = data.get("id", "")
login = data.get("login", "user")
email = f"{gh_id}+{login}@users.noreply.github.com"
identity = {
"GIT_AUTHOR_NAME": name,
"GIT_AUTHOR_EMAIL": email,
"GIT_COMMITTER_NAME": name,
"GIT_COMMITTER_EMAIL": email,
}
_gh_identity_cache[user_id] = identity
return identity

View File

@@ -9,8 +9,6 @@ from backend.copilot.integration_creds import (
_NULL_CACHE_TTL,
_TOKEN_CACHE_TTL,
PROVIDER_ENV_VARS,
_gh_identity_cache,
_gh_identity_null_cache,
_null_cache,
_token_cache,
get_integration_env_vars,
@@ -51,13 +49,9 @@ def clear_caches():
"""Ensure clean caches before and after every test."""
_token_cache.clear()
_null_cache.clear()
_gh_identity_cache.clear()
_gh_identity_null_cache.clear()
yield
_token_cache.clear()
_null_cache.clear()
_gh_identity_cache.clear()
_gh_identity_null_cache.clear()
class TestInvalidateUserProviderCache:
@@ -83,34 +77,6 @@ class TestInvalidateUserProviderCache:
invalidate_user_provider_cache(_USER, _PROVIDER)
assert other_key in _token_cache
def test_clears_gh_identity_cache_for_github_provider(self):
"""When provider is 'github', identity caches must also be cleared."""
_gh_identity_cache[_USER] = {
"GIT_AUTHOR_NAME": "Old Name",
"GIT_AUTHOR_EMAIL": "old@example.com",
"GIT_COMMITTER_NAME": "Old Name",
"GIT_COMMITTER_EMAIL": "old@example.com",
}
invalidate_user_provider_cache(_USER, "github")
assert _USER not in _gh_identity_cache
def test_clears_gh_identity_null_cache_for_github_provider(self):
"""When provider is 'github', the identity null-cache must also be cleared."""
_gh_identity_null_cache[_USER] = True
invalidate_user_provider_cache(_USER, "github")
assert _USER not in _gh_identity_null_cache
def test_does_not_clear_gh_identity_cache_for_other_providers(self):
"""When provider is NOT 'github', identity caches must be left alone."""
_gh_identity_cache[_USER] = {
"GIT_AUTHOR_NAME": "Some Name",
"GIT_AUTHOR_EMAIL": "some@example.com",
"GIT_COMMITTER_NAME": "Some Name",
"GIT_COMMITTER_EMAIL": "some@example.com",
}
invalidate_user_provider_cache(_USER, "some-other-provider")
assert _USER in _gh_identity_cache
class TestGetProviderToken:
@pytest.mark.asyncio(loop_scope="session")
@@ -163,15 +129,8 @@ class TestGetProviderToken:
assert result == "oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_refresh_failure_returns_none_without_null_cache(self):
"""On refresh failure, return None but do NOT cache in null_cache.
The user has credentials — they just couldn't be refreshed right now
(e.g. transient network error or event-loop mismatch in the copilot
executor). Caching a negative result would block all credential
lookups for 60 s even though the creds exist and may refresh fine
on the next attempt.
"""
async def test_oauth2_refresh_failure_returns_none(self):
"""On refresh failure, return None instead of caching a stale token."""
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
@@ -182,8 +141,6 @@ class TestGetProviderToken:
# Stale tokens must NOT be returned — forces re-auth.
assert result is None
# Must NOT cache negative result when refresh failed — next call retries.
assert (_USER, _PROVIDER) not in _null_cache
@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_caches_null_entry(self):
@@ -219,96 +176,6 @@ class TestGetProviderToken:
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
class TestThreadSafetyLocks:
"""Bug reproduction: shared AsyncRedisKeyedMutex across threads caused
'Future attached to a different loop' when copilot workers accessed
credentials from different event loops."""
@pytest.mark.asyncio(loop_scope="session")
async def test_store_locks_returns_per_thread_instance(self):
"""IntegrationCredentialsStore.locks() must return different instances
for different threads (via @thread_cached)."""
import asyncio
import concurrent.futures
from backend.integrations.credentials_store import IntegrationCredentialsStore
store = IntegrationCredentialsStore()
async def get_locks_id():
mock_redis = AsyncMock()
with patch(
"backend.integrations.credentials_store.get_redis_async",
return_value=mock_redis,
):
locks = await store.locks()
return id(locks)
# Get locks from main thread
main_id = await get_locks_id()
# Get locks from a worker thread
def run_in_thread():
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(get_locks_id())
finally:
loop.close()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
worker_id = await asyncio.get_event_loop().run_in_executor(
pool, run_in_thread
)
assert main_id != worker_id, (
"Store.locks() returned the same instance across threads. "
"This would cause 'Future attached to a different loop' errors."
)
@pytest.mark.asyncio(loop_scope="session")
async def test_manager_delegates_to_store_locks(self):
"""IntegrationCredentialsManager.locks() should delegate to store."""
from backend.integrations.creds_manager import IntegrationCredentialsManager
manager = IntegrationCredentialsManager()
mock_redis = AsyncMock()
with patch(
"backend.integrations.credentials_store.get_redis_async",
return_value=mock_redis,
):
locks = await manager.locks()
# Should have gotten it from the store
assert locks is not None
class TestRefreshUnlockedPath:
"""Bug reproduction: copilot worker threads need lock-free refresh because
Redis-backed asyncio.Lock created on one event loop can't be used on another."""
@pytest.mark.asyncio(loop_scope="session")
async def test_refresh_if_needed_lock_false_skips_redis(self):
"""refresh_if_needed(lock=False) must not touch Redis locks at all."""
from backend.integrations.creds_manager import IntegrationCredentialsManager
manager = IntegrationCredentialsManager()
creds = _make_oauth2_creds()
mock_handler = MagicMock()
mock_handler.needs_refresh = MagicMock(return_value=False)
with patch(
"backend.integrations.creds_manager._get_provider_oauth_handler",
new_callable=AsyncMock,
return_value=mock_handler,
):
result = await manager.refresh_if_needed(_USER, creds, lock=False)
# Should return credentials without touching locks
assert result.id == creds.id
class TestGetIntegrationEnvVars:
@pytest.mark.asyncio(loop_scope="session")
async def test_injects_all_env_vars_for_provider(self):

View File

@@ -46,16 +46,6 @@ def _get_session_cache_key(session_id: str) -> str:
# ===================== Chat data models ===================== #
class ChatSessionMetadata(BaseModel):
"""Typed metadata stored in the ``metadata`` JSON column of ChatSession.
Add new session-level flags here instead of adding DB columns —
no migration required for new fields as long as a default is provided.
"""
dry_run: bool = False
class ChatMessage(BaseModel):
role: str
content: str | None = None
@@ -64,8 +54,6 @@ class ChatMessage(BaseModel):
refusal: str | None = None
tool_calls: list[dict] | None = None
function_call: dict | None = None
sequence: int | None = None
duration_ms: int | None = None
@staticmethod
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
@@ -78,54 +66,9 @@ class ChatMessage(BaseModel):
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
sequence=prisma_message.sequence,
duration_ms=prisma_message.durationMs,
)
def is_message_duplicate(
messages: list[ChatMessage],
role: str,
content: str,
) -> bool:
"""Check whether *content* is already present in the current pending turn.
Only inspects trailing messages that share the given *role* (i.e. the
current turn). This ensures legitimately repeated messages across different
turns are not suppressed, while same-turn duplicates from stale cache are
still caught.
"""
for m in reversed(messages):
if m.role == role:
if m.content == content:
return True
else:
break
return False
def maybe_append_user_message(
session: "ChatSession",
message: str | None,
is_user_message: bool,
) -> bool:
"""Append a user/assistant message to the session if not already present.
The route handler already persists the user message before enqueueing,
so we check trailing same-role messages to avoid re-appending when the
session cache is slightly stale.
Returns True if the message was appended, False if skipped.
"""
if not message:
return False
role = "user" if is_user_message else "assistant"
if is_message_duplicate(session.messages, role, message):
return False
session.messages.append(ChatMessage(role=role, content=message))
return True
class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
@@ -145,12 +88,6 @@ class ChatSessionInfo(BaseModel):
updated_at: datetime
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
metadata: ChatSessionMetadata = ChatSessionMetadata()
@property
def dry_run(self) -> bool:
"""Convenience accessor for ``metadata.dry_run``."""
return self.metadata.dry_run
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
@@ -164,10 +101,6 @@ class ChatSessionInfo(BaseModel):
prisma_session.successfulAgentSchedules, default={}
)
# Parse typed metadata from the JSON column.
raw_metadata = _parse_json_field(prisma_session.metadata, default={})
metadata = ChatSessionMetadata.model_validate(raw_metadata)
# Calculate usage from token counts.
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
# is lost after persistence — the DB only stores aggregate prompt and
@@ -193,7 +126,6 @@ class ChatSessionInfo(BaseModel):
updated_at=prisma_session.updatedAt,
successful_agent_runs=successful_agent_runs,
successful_agent_schedules=successful_agent_schedules,
metadata=metadata,
)
@@ -201,7 +133,7 @@ class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]
@classmethod
def new(cls, user_id: str, *, dry_run: bool) -> Self:
def new(cls, user_id: str) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
@@ -211,7 +143,6 @@ class ChatSession(ChatSessionInfo):
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metadata=ChatSessionMetadata(dry_run=dry_run),
)
@classmethod
@@ -599,7 +530,6 @@ async def _save_session_to_db(
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
metadata=session.metadata,
)
existing_message_count = 0
@@ -677,27 +607,21 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
return session
async def create_chat_session(user_id: str, *, dry_run: bool) -> ChatSession:
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
Args:
user_id: The authenticated user ID.
dry_run: When True, run_block and run_agent tool calls in this
session are forced to use dry-run simulation mode.
Raises:
DatabaseError: If the database write fails. We fail fast to ensure
callers never receive a non-persisted session that only exists
in cache (which would be lost when the cache expires).
"""
session = ChatSession.new(user_id, dry_run=dry_run)
session = ChatSession.new(user_id)
# Create in database first - fail fast if this fails
try:
await chat_db().create_chat_session(
session_id=session.session_id,
user_id=user_id,
metadata=session.metadata,
)
except Exception as e:
logger.error(f"Failed to create session {session.session_id} in database: {e}")

View File

@@ -17,8 +17,6 @@ from .model import (
ChatSession,
Usage,
get_chat_session,
is_message_duplicate,
maybe_append_user_message,
upsert_chat_session,
)
@@ -48,7 +46,7 @@ messages = [
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
s = ChatSession.new(user_id="abc123", dry_run=False)
s = ChatSession.new(user_id="abc123")
s.messages = messages
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
serialized = s.model_dump_json()
@@ -59,7 +57,7 @@ async def test_chatsession_serialization_deserialization():
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
@@ -77,7 +75,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
setup_test_user, test_user_id
):
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
@@ -92,7 +90,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
@@ -243,7 +241,7 @@ _raw_tc2 = {
def test_add_tool_call_appends_to_existing_assistant():
"""When the last assistant is from the current turn, tool_call is added to it."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="working on it"),
@@ -256,7 +254,7 @@ def test_add_tool_call_appends_to_existing_assistant():
def test_add_tool_call_creates_assistant_when_none_exists():
"""When there's no current-turn assistant, a new one is created."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
]
@@ -269,7 +267,7 @@ def test_add_tool_call_creates_assistant_when_none_exists():
def test_add_tool_call_does_not_cross_user_boundary():
"""A user message acts as a boundary — previous assistant is not modified."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="assistant", content="old turn"),
ChatMessage(role="user", content="new message"),
@@ -284,7 +282,7 @@ def test_add_tool_call_does_not_cross_user_boundary():
def test_add_tool_call_multiple_times():
"""Multiple long-running tool calls accumulate on the same assistant."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="doing stuff"),
@@ -302,7 +300,7 @@ def test_add_tool_call_multiple_times():
def test_to_openai_messages_merges_split_assistants():
"""End-to-end: session with split assistants produces valid OpenAI messages."""
session = ChatSession.new(user_id="u", dry_run=False)
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="build agent"),
ChatMessage(role="assistant", content="Let me build that"),
@@ -354,7 +352,7 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
import asyncio
# Create a session with initial messages
session = ChatSession.new(user_id=test_user_id, dry_run=False)
session = ChatSession.new(user_id=test_user_id)
for i in range(3):
session.messages.append(
ChatMessage(
@@ -426,151 +424,3 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
assert "Streaming message 1" in contents
assert "Streaming message 2" in contents
assert "Callback result" in contents
# --------------------------------------------------------------------------- #
# is_message_duplicate #
# --------------------------------------------------------------------------- #
def test_duplicate_detected_in_trailing_same_role():
"""Duplicate user message at the tail is detected."""
msgs = [
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi there"),
ChatMessage(role="user", content="yes"),
]
assert is_message_duplicate(msgs, "user", "yes") is True
def test_duplicate_not_detected_across_turns():
"""Same text in a previous turn (separated by assistant) is NOT a duplicate."""
msgs = [
ChatMessage(role="user", content="yes"),
ChatMessage(role="assistant", content="ok"),
]
assert is_message_duplicate(msgs, "user", "yes") is False
def test_no_duplicate_on_empty_messages():
"""Empty message list never reports a duplicate."""
assert is_message_duplicate([], "user", "hello") is False
def test_no_duplicate_when_content_differs():
"""Different content in the trailing same-role block is not a duplicate."""
msgs = [
ChatMessage(role="assistant", content="response"),
ChatMessage(role="user", content="first message"),
]
assert is_message_duplicate(msgs, "user", "second message") is False
def test_duplicate_with_multiple_trailing_same_role():
"""Detects duplicate among multiple consecutive same-role messages."""
msgs = [
ChatMessage(role="assistant", content="response"),
ChatMessage(role="user", content="msg1"),
ChatMessage(role="user", content="msg2"),
]
assert is_message_duplicate(msgs, "user", "msg1") is True
assert is_message_duplicate(msgs, "user", "msg2") is True
assert is_message_duplicate(msgs, "user", "msg3") is False
def test_duplicate_check_for_assistant_role():
"""Works correctly when checking assistant role too."""
msgs = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="hello"),
ChatMessage(role="assistant", content="how can I help?"),
]
assert is_message_duplicate(msgs, "assistant", "hello") is True
assert is_message_duplicate(msgs, "assistant", "new response") is False
def test_no_false_positive_when_content_is_none():
"""Messages with content=None in the trailing block do not match."""
msgs = [
ChatMessage(role="user", content=None),
ChatMessage(role="user", content="hello"),
]
assert is_message_duplicate(msgs, "user", "hello") is True
# None-content message should not match any string
msgs2 = [
ChatMessage(role="user", content=None),
]
assert is_message_duplicate(msgs2, "user", "hello") is False
def test_all_same_role_messages():
"""When all messages share the same role, the entire list is scanned."""
msgs = [
ChatMessage(role="user", content="first"),
ChatMessage(role="user", content="second"),
ChatMessage(role="user", content="third"),
]
assert is_message_duplicate(msgs, "user", "first") is True
assert is_message_duplicate(msgs, "user", "new") is False
# --------------------------------------------------------------------------- #
# maybe_append_user_message #
# --------------------------------------------------------------------------- #
def test_maybe_append_user_message_appends_new():
"""A new user message is appended and returns True."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="assistant", content="hello"),
]
result = maybe_append_user_message(session, "new msg", is_user_message=True)
assert result is True
assert len(session.messages) == 2
assert session.messages[-1].role == "user"
assert session.messages[-1].content == "new msg"
def test_maybe_append_user_message_skips_duplicate():
"""A duplicate user message is skipped and returns False."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="assistant", content="hello"),
ChatMessage(role="user", content="dup"),
]
result = maybe_append_user_message(session, "dup", is_user_message=True)
assert result is False
assert len(session.messages) == 2
def test_maybe_append_user_message_none_message():
"""None/empty message returns False without appending."""
session = ChatSession.new(user_id="u", dry_run=False)
assert maybe_append_user_message(session, None, is_user_message=True) is False
assert maybe_append_user_message(session, "", is_user_message=True) is False
assert len(session.messages) == 0
def test_maybe_append_assistant_message():
"""Works for assistant role when is_user_message=False."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
]
result = maybe_append_user_message(session, "response", is_user_message=False)
assert result is True
assert session.messages[-1].role == "assistant"
assert session.messages[-1].content == "response"
def test_maybe_append_assistant_skips_duplicate():
"""Duplicate assistant message is skipped."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="dup"),
]
result = maybe_append_user_message(session, "dup", is_user_message=False)
assert result is False
assert len(session.messages) == 2

View File

@@ -66,7 +66,6 @@ from pydantic import BaseModel, PrivateAttr
ToolName = Literal[
# Platform tools (must match keys in TOOL_REGISTRY)
"add_understanding",
"ask_question",
"bash_exec",
"browser_act",
"browser_navigate",
@@ -103,7 +102,6 @@ ToolName = Literal[
"web_fetch",
"write_workspace_file",
# SDK built-ins
"Agent",
"Edit",
"Glob",
"Grep",

View File

@@ -544,7 +544,6 @@ class TestApplyToolPermissions:
class TestSdkBuiltinToolNames:
def test_expected_builtins_present(self):
expected = {
"Agent",
"Read",
"Write",
"Edit",

View File

@@ -18,18 +18,6 @@ After `write_workspace_file`, embed the `download_url` in Markdown:
- Image: `![chart](workspace://file_id#image/png)`
- Video: `![recording](workspace://file_id#video/mp4)`
### Handling binary/image data in tool outputs — CRITICAL
When a tool output contains base64-encoded binary data (images, PDFs, etc.):
1. **NEVER** try to inline or render the base64 content in your response.
2. **Save** the data to workspace using `write_workspace_file` (pass the base64 data URI as content).
3. **Show** the result via the workspace download URL in Markdown: `![image](workspace://file_id#image/png)`.
### Passing large data between tools — CRITICAL
When tool outputs produce large text that you need to feed into another tool:
- **NEVER** copy-paste the full text into the next tool call argument.
- **Save** the output to a file (workspace or local), then use `@@agptfile:` references.
- This avoids token limits and ensures data integrity.
### File references — @@agptfile:
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
@@ -119,28 +107,6 @@ Do not re-fetch or re-generate data you already have from prior tool calls.
After building the file, reference it with `@@agptfile:` in other tools:
`@@agptfile:/home/user/report.md`
### Web search best practices
- If 3 similar web searches don't return the specific data you need, conclude
it isn't publicly available and work with what you have.
- Prefer fewer, well-targeted searches over many variations of the same query.
- When spawning sub-agents for research, ensure each has a distinct
non-overlapping scope to avoid redundant searches.
### Tool Discovery Priority
When the user asks to interact with a service or API, follow this order:
1. **find_block first** — Search platform blocks with `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Docs, Calendar, Gmail, Slack, GitHub, etc.) that work without extra setup.
2. **run_mcp_tool** — If no matching block exists, check if a hosted MCP server is available for the service. Only use known MCP server URLs from the registry.
3. **SendAuthenticatedWebRequestBlock** — If no block or MCP server exists, use `SendAuthenticatedWebRequestBlock` with existing host-scoped credentials. Check available credentials via `connect_integration`.
4. **Manual API call** — As a last resort, guide the user to set up credentials and use `SendAuthenticatedWebRequestBlock` with direct API calls.
**Never skip step 1.** Built-in blocks are more reliable, tested, and user-friendly than MCP or raw API calls.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
@@ -165,11 +131,6 @@ parent autopilot handles orchestration.
# E2B-only notes — E2B has full internet access so gh CLI works there.
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
_E2B_TOOL_NOTES = """
### SDK tool-result files in E2B
When you `Read` an SDK tool-result file, it is automatically copied into the
sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
### GitHub CLI (`gh`) and git
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
@@ -235,22 +196,19 @@ def _build_storage_supplement(
- Files here **survive across sessions indefinitely**
### Moving files between storages
- **{file_move_name_1_to_2}**: `write_workspace_file(filename="output.json", source_path="/path/to/local/file")`
- **{file_move_name_2_to_1}**: `read_workspace_file(path="tool-outputs/data.json", save_to_path="{working_dir}/data.json")`
- **{file_move_name_1_to_2}**: Copy to persistent workspace
- **{file_move_name_2_to_1}**: Download for processing
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
### SDK tool-result files
When tool outputs are large, the SDK truncates them and saves the full output to
a local file under `~/.claude/projects/.../tool-results/` (or `tool-outputs/`).
To read these files, use `Read` — it reads from the host filesystem.
### Large tool outputs saved to workspace
When a tool output contains `<tool-output-truncated workspace_path="...">`, the
full output is in workspace storage (NOT on the local filesystem). To access it:
- Use `read_workspace_file(path="...", offset=..., length=50000)` for reading sections.
- To process in the sandbox, use `read_workspace_file(path="...", save_to_path="{working_dir}/file.json")` first, then use `bash_exec` on the local copy.
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
always use `Read` (NOT `bash_exec`, NOT `read_workspace_file`).
These files are on the host filesystem — `bash_exec` runs in the sandbox and
CANNOT access them. `read_workspace_file` reads from cloud workspace storage,
where SDK tool-results are NOT stored.
{_SHARED_TOOL_NOTES}{extra_notes}"""

View File

@@ -1,28 +0,0 @@
"""Tests for agent generation guide — verifies clarification section."""
from pathlib import Path
class TestAgentGenerationGuideContainsClarifySection:
"""The agent generation guide must include the clarification section."""
def test_guide_includes_clarify_section(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
assert "Before or During Building" in content
def test_guide_mentions_find_block_for_clarification(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
clarify_section = content.split("Before or During Building")[1].split(
"### Workflow"
)[0]
assert "find_block" in clarify_section
def test_guide_mentions_ask_question_tool(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
clarify_section = content.split("Before or During Building")[1].split(
"### Workflow"
)[0]
assert "ask_question" in clarify_section

View File

@@ -9,14 +9,11 @@ UTC). Fails open when Redis is unavailable to avoid blocking users.
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from enum import Enum
from prisma.models import User as PrismaUser
from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from backend.data.redis_client import get_redis_async
from backend.util.cache import cached
logger = logging.getLogger(__name__)
@@ -24,40 +21,6 @@ logger = logging.getLogger(__name__)
_USAGE_KEY_PREFIX = "copilot:usage"
# ---------------------------------------------------------------------------
# Subscription tier definitions
# ---------------------------------------------------------------------------
class SubscriptionTier(str, Enum):
"""Subscription tiers with increasing token allowances.
Mirrors the ``SubscriptionTier`` enum in ``schema.prisma``.
Once ``prisma generate`` is run, this can be replaced with::
from prisma.enums import SubscriptionTier
"""
FREE = "FREE"
PRO = "PRO"
BUSINESS = "BUSINESS"
ENTERPRISE = "ENTERPRISE"
# Multiplier applied to the base limits (from LD / config) for each tier.
# Intentionally int (not float): keeps limits as whole token counts and avoids
# floating-point rounding. If fractional multipliers are ever needed, change
# the type and round the result in get_global_rate_limits().
TIER_MULTIPLIERS: dict[SubscriptionTier, int] = {
SubscriptionTier.FREE: 1,
SubscriptionTier.PRO: 5,
SubscriptionTier.BUSINESS: 20,
SubscriptionTier.ENTERPRISE: 60,
}
DEFAULT_TIER = SubscriptionTier.FREE
class UsageWindow(BaseModel):
"""Usage within a single time window."""
@@ -73,7 +36,6 @@ class CoPilotUsageStatus(BaseModel):
daily: UsageWindow
weekly: UsageWindow
tier: SubscriptionTier = DEFAULT_TIER
reset_cost: int = Field(
default=0,
description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
@@ -104,7 +66,6 @@ async def get_usage_status(
daily_token_limit: int,
weekly_token_limit: int,
rate_limit_reset_cost: int = 0,
tier: SubscriptionTier = DEFAULT_TIER,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
@@ -113,7 +74,6 @@ async def get_usage_status(
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
tier: The user's rate-limit tier (included in the response).
Returns:
CoPilotUsageStatus with current usage and limits.
@@ -143,7 +103,6 @@ async def get_usage_status(
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
tier=tier,
reset_cost=rate_limit_reset_cost,
)
@@ -202,9 +161,8 @@ async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
daily_token_limit: The configured daily token limit. When positive,
the weekly counter is reduced by this amount.
Returns False if Redis is unavailable so the caller can handle
compensation (fail-closed for billed operations, unlike the read-only
rate-limit checks which fail-open).
Fails open: returns False if Redis is unavailable (consistent with
the fail-open design of this module).
"""
now = datetime.now(UTC)
try:
@@ -384,100 +342,20 @@ async def record_token_usage(
)
class _UserNotFoundError(Exception):
"""Raised when a user record is missing or has no subscription tier.
Used internally by ``_fetch_user_tier`` to signal a cache-miss condition:
by raising instead of returning ``DEFAULT_TIER``, we prevent the ``@cached``
decorator from storing the fallback value. This avoids a race condition
where a non-existent user's DEFAULT_TIER is cached, then the user is
created with a higher tier but receives the stale cached FREE tier for
up to 5 minutes.
"""
@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
"""Fetch the user's rate-limit tier from the database (cached via Redis).
Uses ``shared_cache=True`` so that tier changes propagate across all pods
immediately when the cache entry is invalidated (via ``cache_delete``).
Only successful DB lookups of existing users with a valid tier are cached.
Raises ``_UserNotFoundError`` when the user is missing or has no tier, so
the ``@cached`` decorator does **not** store a fallback value. This
prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
cached and then persists after the user is created with a higher tier.
"""
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
if user and user.subscriptionTier: # type: ignore[reportAttributeAccessIssue]
return SubscriptionTier(user.subscriptionTier) # type: ignore[reportAttributeAccessIssue]
raise _UserNotFoundError(user_id)
async def get_user_tier(user_id: str) -> SubscriptionTier:
"""Look up the user's rate-limit tier from the database.
Successful results are cached for 5 minutes (via ``_fetch_user_tier``)
to avoid a DB round-trip on every rate-limit check.
Falls back to ``DEFAULT_TIER`` **without caching** when the DB is
unreachable or returns an unrecognised value, so the next call retries
the query instead of serving a stale fallback for up to 5 minutes.
"""
try:
return await _fetch_user_tier(user_id)
except Exception as exc:
logger.warning(
"Failed to resolve rate-limit tier for user %s, defaulting to %s: %s",
user_id[:8],
DEFAULT_TIER.value,
exc,
)
return DEFAULT_TIER
# Expose cache management on the public function so callers (including tests)
# never need to reach into the private ``_fetch_user_tier``.
get_user_tier.cache_clear = _fetch_user_tier.cache_clear # type: ignore[attr-defined]
get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-defined]
async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
"""Persist the user's rate-limit tier to the database.
Also invalidates the ``get_user_tier`` cache for this user so that
subsequent rate-limit checks immediately see the new tier.
Raises:
prisma.errors.RecordNotFoundError: If the user does not exist.
"""
await PrismaUser.prisma().update(
where={"id": user_id},
data={"subscriptionTier": tier.value},
)
# Invalidate cached tier so rate-limit checks pick up the change immediately.
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
async def get_global_rate_limits(
user_id: str,
config_daily: int,
config_weekly: int,
) -> tuple[int, int, SubscriptionTier]:
) -> tuple[int, int]:
"""Resolve global rate limits from LaunchDarkly, falling back to config.
The base limits (from LD or config) are multiplied by the user's
tier multiplier so that higher tiers receive proportionally larger
allowances.
Args:
user_id: User ID for LD flag evaluation context.
config_daily: Fallback daily limit from ChatConfig.
config_weekly: Fallback weekly limit from ChatConfig.
Returns:
(daily_token_limit, weekly_token_limit, tier) 3-tuple.
(daily_token_limit, weekly_token_limit) tuple.
"""
# Lazy import to avoid circular dependency:
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
@@ -499,15 +377,7 @@ async def get_global_rate_limits(
except (TypeError, ValueError):
logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
weekly = config_weekly
# Apply tier multiplier
tier = await get_user_tier(user_id)
multiplier = TIER_MULTIPLIERS.get(tier, 1)
if multiplier != 1:
daily = daily * multiplier
weekly = weekly * multiplier
return daily, weekly, tier
return daily, weekly
async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,7 @@ import pytest
from fastapi import HTTPException
from backend.api.features.chat.routes import reset_copilot_usage
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from backend.util.exceptions import InsufficientBalanceError
@@ -53,18 +53,6 @@ def _mock_settings(enable_credit: bool = True):
return mock
def _mock_rate_limits(
daily: int = 2_500_000,
weekly: int = 12_500_000,
tier: SubscriptionTier = SubscriptionTier.PRO,
):
"""Mock get_global_rate_limits to return fixed limits (no tier multiplier)."""
return patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(daily, weekly, tier)),
)
@pytest.mark.asyncio
class TestResetCopilotUsage:
async def test_feature_disabled_returns_400(self):
@@ -82,7 +70,6 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(daily=0),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
@@ -96,7 +83,6 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -126,7 +112,6 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -156,7 +141,6 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -187,7 +171,6 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
):
with pytest.raises(HTTPException) as exc_info:
@@ -225,7 +208,6 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -246,7 +228,6 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", _make_config()),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
):
with pytest.raises(HTTPException) as exc_info:
@@ -264,7 +245,6 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -295,7 +275,6 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),

View File

@@ -3,62 +3,26 @@
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.
### Clarifying — Before or During Building
Use `ask_question` whenever the user's intent is ambiguous — whether
that's before starting or midway through the workflow. Common moments:
- **Before building**: output format, delivery channel, data source, or
trigger is unspecified.
- **During block discovery**: multiple blocks could fit and the user
should choose.
- **During JSON generation**: a wiring decision depends on user
preference.
Steps:
1. Call `find_block` (or another discovery tool) to learn what the
platform actually supports for the ambiguous dimension.
2. Call `ask_question` with a concrete question listing the discovered
options (e.g. "The platform supports Gmail, Slack, and Google Docs —
which should the agent use for delivery?").
3. **Wait for the user's answer** before continuing.
**Skip this** when the goal already specifies all dimensions (e.g.
"scrape prices from Amazon and email me daily").
### Workflow for Creating/Editing Agents
1. **If editing**: First narrow to the specific agent by UUID, then fetch its
graph: `find_library_agent(query="<agent_id>", include_graph=true)`. This
returns the full graph structure (nodes + links). **Never edit blindly**
always inspect the current graph first so you know exactly what to change.
Avoid using `include_graph=true` with broad keyword searches, as fetching
multiple graphs at once is expensive and consumes LLM context budget.
2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
search for relevant blocks. This returns block IDs, names, descriptions,
and full input/output schemas.
3. **Find library agents**: Call `find_library_agent` to discover reusable
2. **Find library agents**: Call `find_library_agent` to discover reusable
agents that can be composed as sub-agents via `AgentExecutorBlock`.
4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas:
- Use block IDs from step 2 as `block_id` in nodes
3. **Generate JSON**: Build the agent JSON using block schemas:
- Use block IDs from step 1 as `block_id` in nodes
- Wire outputs to inputs using links
- Set design-time config in `input_default`
- Use `AgentInputBlock` for values the user provides at runtime
- When editing, apply targeted changes and preserve unchanged parts
5. **Write to workspace**: Save the JSON to a workspace file so the user
4. **Write to workspace**: Save the JSON to a workspace file so the user
can review it: `write_workspace_file(filename="agent.json", content=...)`
6. **Validate**: Call `validate_agent_graph` with the agent JSON to check
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
for errors
7. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
or fix manually based on the error descriptions. Iterate until valid.
8. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
the final `agent_json`
8. **Dry-run**: ALWAYS call `run_agent` with `dry_run=True` and
`wait_for_result=120` to verify the agent works end-to-end.
9. **Inspect & fix**: Check the dry-run output for errors. If issues are
found, call `edit_agent` to fix and dry-run again. Repeat until the
simulation passes or the problems are clearly unfixable.
See "REQUIRED: Dry-Run Verification Loop" section below for details.
### Agent JSON Structure
@@ -110,8 +74,8 @@ These define the agent's interface — what it accepts and what it produces.
**AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
- Specialized input block that presents a dropdown/select to the user
- Required `input_default` fields: `name` (str)
- Optional: `options` (list of dropdown values; when omitted/empty, input behaves as free-text), `title`, `description`, `value` (default selection)
- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
- Optional: `title`, `description`, `value` (default selection)
- Output: `result` — the user-selected value at runtime
- Use this instead of AgentInputBlock when the user should pick from a fixed set of options
@@ -252,62 +216,19 @@ call in a loop until the task is complete:
Regular blocks work exactly like sub-agents as tools — wire each input
field from `source_name: "tools"` on the Orchestrator side.
### REQUIRED: Dry-Run Verification Loop (create -> dry-run -> fix)
### Testing with Dry Run
After creating or editing an agent, you MUST dry-run it before telling the
user the agent is ready. NEVER skip this step.
After saving an agent, suggest a dry run to validate wiring without consuming
real API calls, credentials, or credits:
#### Step-by-step workflow
1. **Create/Edit**: Call `create_agent` or `edit_agent` to save the agent.
2. **Dry-run**: Call `run_agent` with `dry_run=True`, `wait_for_result=120`,
and realistic sample inputs that exercise every path in the agent. This
simulates execution using an LLM for each block — no real API calls,
credentials, or credits are consumed.
3. **Inspect output**: Examine the dry-run result for problems. If
`wait_for_result` returns only a summary, call
`view_agent_output(execution_id=..., show_execution_details=True)` to
see the full node-by-node execution trace. Look for:
- **Errors / failed nodes** — a node raised an exception or returned an
error status. Common causes: wrong `source_name`/`sink_name` in links,
missing `input_default` values, or referencing a nonexistent block output.
- **Null / empty outputs** — data did not flow through a link. Verify that
`source_name` and `sink_name` match the block schemas exactly (case-
sensitive, including nested `_#_` notation).
- **Nodes that never executed** — the node was not reached. Likely a
missing or broken link from an upstream node.
- **Unexpected values** — data arrived but in the wrong type or
structure. Check type compatibility between linked ports.
4. **Fix**: If any issues are found, call `edit_agent` with the corrected
agent JSON, then go back to step 2.
5. **Repeat**: Continue the dry-run -> fix cycle until the simulation passes
or the problems are clearly unfixable. If you stop making progress,
report the remaining issues to the user and ask for guidance.
#### Good vs bad dry-run output
**Good output** (agent is ready):
- All nodes executed successfully (no errors in the execution trace)
- Data flows through every link with non-null, correctly-typed values
- The final `AgentOutputBlock` contains a meaningful result
- Status is `COMPLETED`
**Bad output** (needs fixing):
- Status is `FAILED` — check the error message for the failing node
- An output node received `null` — trace back to find the broken link
- A node received data in the wrong format (e.g. string where list expected)
- Nodes downstream of a failing node were skipped entirely
**Special block behaviour in dry-run mode:**
- **OrchestratorBlock** and **AgentExecutorBlock** execute for real so the
orchestrator can make LLM calls and agent executors can spawn child graphs.
Their downstream tool blocks and child-graph blocks are still simulated.
Note: real LLM inference calls are made (consuming API quota), even though
platform credits are not charged. Agent-mode iterations are capped at 1 in
dry-run to keep it fast.
- **MCPToolBlock** is simulated using the selected tool's name and JSON Schema
so the LLM can produce a realistic mock response without connecting to the
MCP server.
1. **Run**: Call `run_agent` or `run_block` with `dry_run=True` and provide
sample inputs. This executes the graph with mock outputs, verifying that
links resolve correctly and required inputs are satisfied.
2. **Check results**: Call `view_agent_output` with `show_execution_details=True`
to inspect the full node-by-node execution trace. This shows what each node
received as input and produced as output, making it easy to spot wiring issues.
3. **Iterate**: If the dry run reveals wiring issues or missing inputs, fix
the agent JSON and re-save before suggesting a real execution.
### Example: Simple AI Text Processor

View File

@@ -25,7 +25,7 @@ from backend.copilot.sdk.compaction import (
def _make_session() -> ChatSession:
return ChatSession.new(user_id="test-user", dry_run=False)
return ChatSession.new(user_id="test-user")
# ---------------------------------------------------------------------------

Some files were not shown because too many files have changed in this diff Show More