mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
2 Commits
harness
...
remove-cla
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f20693d02b | ||
|
|
a4188c5657 |
@@ -1 +0,0 @@
|
||||
../.claude/skills
|
||||
@@ -1,10 +0,0 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allowedTools": [
|
||||
"Read", "Grep", "Glob",
|
||||
"Bash(ls:*)", "Bash(cat:*)", "Bash(grep:*)", "Bash(find:*)",
|
||||
"Bash(git status:*)", "Bash(git diff:*)", "Bash(git log:*)", "Bash(git worktree:*)",
|
||||
"Bash(tmux:*)", "Bash(sleep:*)", "Bash(branchlet:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
---
|
||||
name: open-pr
|
||||
description: Open a pull request with proper PR template, test coverage, and review workflow. Guides agents through creating a PR that follows repo conventions, ensures existing behaviors aren't broken, covers new behaviors with tests, and handles review via bot when local testing isn't possible. TRIGGER when user asks to "open a PR", "create a PR", "make a PR", "submit a PR", "open pull request", "push and create PR", or any variation of opening/submitting a pull request.
|
||||
user-invocable: true
|
||||
args: "[base-branch] — optional target branch (defaults to dev)."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Open a Pull Request
|
||||
|
||||
## Step 1: Pre-flight checks
|
||||
|
||||
Before opening the PR:
|
||||
|
||||
1. Ensure all changes are committed
|
||||
2. Ensure the branch is pushed to the remote (`git push -u origin <branch>`)
|
||||
3. Run linters/formatters across the whole repo (not just changed files) and commit any fixes
|
||||
|
||||
## Step 2: Test coverage
|
||||
|
||||
**This is critical.** Before opening the PR, verify:
|
||||
|
||||
### Existing behavior is not broken
|
||||
- Identify which modules/components your changes touch
|
||||
- Run the existing test suites for those areas
|
||||
- If tests fail, fix them before opening the PR — do not open a PR with known regressions
|
||||
|
||||
### New behavior has test coverage
|
||||
- Every new feature, endpoint, or behavior change needs tests
|
||||
- If you added a new block, add tests for that block
|
||||
- If you changed API behavior, add or update API tests
|
||||
- If you changed frontend behavior, verify it doesn't break existing flows
|
||||
|
||||
If you cannot run the full test suite locally, note which tests you ran and which you couldn't in the test plan.
|
||||
|
||||
## Step 3: Create the PR using the repo template
|
||||
|
||||
Read the canonical PR template at `.github/PULL_REQUEST_TEMPLATE.md` and use it **verbatim** as your PR body:
|
||||
|
||||
1. Read the template: `cat .github/PULL_REQUEST_TEMPLATE.md`
|
||||
2. Preserve the exact section titles and formatting, including:
|
||||
- `### Why / What / How`
|
||||
- `### Changes 🏗️`
|
||||
- `### Checklist 📋`
|
||||
3. Replace HTML comment prompts (`<!-- ... -->`) with actual content; do not leave them in
|
||||
4. **Do not pre-check boxes** — leave all checkboxes as `- [ ]` until each step is actually completed
|
||||
5. Do not alter the template structure, rename sections, or remove any checklist items
|
||||
|
||||
**PR title must use conventional commit format** (e.g., `feat(backend): add new block`, `fix(frontend): resolve routing bug`, `dx(skills): update PR workflow`). See CLAUDE.md for the full list of scopes.
|
||||
|
||||
Use `gh pr create` with the base branch (defaults to `dev` if no `[base-branch]` was provided). Use `--body-file` to avoid shell interpretation of backticks and special characters:
|
||||
|
||||
```bash
|
||||
BASE_BRANCH="${BASE_BRANCH:-dev}"
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
<filled-in template from .github/PULL_REQUEST_TEMPLATE.md>
|
||||
PREOF
|
||||
gh pr create --base "$BASE_BRANCH" --title "<type>(scope): short description" --body-file "$PR_BODY"
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
|
||||
## Step 4: Review workflow
|
||||
|
||||
### If you have a workspace that allows testing (docker, running backend, etc.)
|
||||
- Run `/pr-test` to do E2E manual testing of the PR using docker compose, agent-browser, and API calls. This is the most thorough way to validate your changes before review.
|
||||
- After testing, run `/pr-review` to self-review the PR for correctness, security, code quality, and testing gaps before requesting human review.
|
||||
|
||||
### If you do NOT have a workspace that allows testing
|
||||
This is common for agents running in worktrees without a full stack. In this case:
|
||||
|
||||
1. Run `/pr-review` locally to catch obvious issues before pushing
|
||||
2. **Comment `/review` on the PR** after creating it to trigger the review bot
|
||||
3. **Poll for the review** rather than blindly waiting — check for new review comments every 30 seconds using `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` and the GraphQL inline threads query. The bot typically responds within 30 minutes, but polling lets the agent react as soon as it arrives.
|
||||
4. Do NOT proceed or merge until the bot review comes back
|
||||
5. Address any issues the bot raises — use `/pr-address` which has a full polling loop with CI + comment tracking
|
||||
|
||||
```bash
|
||||
# After creating the PR:
|
||||
PR_NUMBER=$(gh pr view --json number -q .number)
|
||||
gh pr comment "$PR_NUMBER" --body "/review"
|
||||
# Then use /pr-address to poll for and address the review when it arrives
|
||||
```
|
||||
|
||||
## Step 5: Address review feedback
|
||||
|
||||
Once the review bot or human reviewers leave comments:
|
||||
- Run `/pr-address` to address review comments. It will loop until CI is green and all comments are resolved.
|
||||
- Do not merge without human approval.
|
||||
|
||||
## Related skills
|
||||
|
||||
| Skill | When to use |
|
||||
|---|---|
|
||||
| `/pr-test` | E2E testing with docker compose, agent-browser, API calls — use when you have a running workspace |
|
||||
| `/pr-review` | Review for correctness, security, code quality — use before requesting human review |
|
||||
| `/pr-address` | Address reviewer comments and loop until CI green — use after reviews come in |
|
||||
|
||||
## Step 6: Post-creation
|
||||
|
||||
After the PR is created and review is triggered:
|
||||
- Share the PR URL with the user
|
||||
- If waiting on the review bot, let the user know the expected wait time (~30 min)
|
||||
- Do not merge without human approval
|
||||
@@ -1,709 +0,0 @@
|
||||
---
|
||||
name: orchestrate
|
||||
description: "Meta-agent supervisor that manages a fleet of Claude Code agents running in tmux windows. Auto-discovers spare worktrees, spawns agents, monitors state, kicks idle agents, approves safe confirmations, and recycles worktrees when done. TRIGGER when user asks to supervise agents, run parallel tasks, manage worktrees, check agent status, or orchestrate parallel work."
|
||||
user-invocable: true
|
||||
argument-hint: "any free text — e.g. 'start 3 agents on X Y Z', 'show status', 'add task: implement feature A', 'stop', 'how many are free?'"
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "6.0.0"
|
||||
---
|
||||
|
||||
# Orchestrate — Agent Fleet Supervisor
|
||||
|
||||
One tmux session, N windows — each window is one agent working in its own worktree. Speak naturally; Claude maps your intent to the right scripts.
|
||||
|
||||
## Scripts
|
||||
|
||||
```bash
|
||||
SKILLS_DIR=$(git rev-parse --show-toplevel)/.claude/skills/orchestrate/scripts
|
||||
STATE_FILE=~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
| Script | Purpose |
|
||||
|---|---|
|
||||
| `find-spare.sh [REPO_ROOT]` | List free worktrees — one `PATH BRANCH` per line |
|
||||
| `spawn-agent.sh SESSION PATH SPARE NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]` | Create window + checkout branch + launch claude + send task. **Stdout: `SESSION:WIN` only** |
|
||||
| `recycle-agent.sh WINDOW PATH SPARE_BRANCH` | Kill window + restore spare branch |
|
||||
| `run-loop.sh` | **Mechanical babysitter** — idle restart + dialog approval + recycle on ORCHESTRATOR:DONE + supervisor health check + all-done notification |
|
||||
| `verify-complete.sh WINDOW` | Verify PR is done: checkpoints ✓ + 0 unresolved threads + CI green + no fresh CHANGES_REQUESTED. Repo auto-derived from state file `.repo` or git remote. |
|
||||
| `notify.sh MESSAGE` | Send notification via Discord webhook (env `DISCORD_WEBHOOK_URL` or state `.discord_webhook`), macOS notification center, and stdout |
|
||||
| `capacity.sh [REPO_ROOT]` | Print available + in-use worktrees |
|
||||
| `status.sh` | Print fleet status + live pane commands |
|
||||
| `poll-cycle.sh` | One monitoring cycle — classifies panes, tracks checkpoints, returns JSON action array |
|
||||
| `classify-pane.sh WINDOW` | Classify one pane state |
|
||||
|
||||
## Supervision model
|
||||
|
||||
```
|
||||
Orchestrating Claude (this Claude session — IS the supervisor)
|
||||
└── Reads pane output, checks CI, intervenes with targeted guidance
|
||||
run-loop.sh (separate tmux window, every 30s)
|
||||
└── Mechanical only: idle restart, dialog approval, recycle on ORCHESTRATOR:DONE
|
||||
```
|
||||
|
||||
**You (the orchestrating Claude)** are the supervisor. After spawning agents, stay in this conversation and actively monitor: poll each agent's pane every 2-3 minutes, check CI, nudge stalled agents, and verify completions. Do not spawn a separate supervisor Claude window — it loses context, is hard to observe, and compounds context compression problems.
|
||||
|
||||
**run-loop.sh** is the mechanical layer — zero tokens, handles things that need no judgment: restart crashed agents, press Enter on dialogs, recycle completed worktrees (only after `verify-complete.sh` passes).
|
||||
|
||||
## Checkpoint protocol
|
||||
|
||||
Agents output checkpoints as they complete each required step:
|
||||
|
||||
```
|
||||
CHECKPOINT:<step-name>
|
||||
```
|
||||
|
||||
Required steps are passed as args to `spawn-agent.sh` (e.g. `pr-address pr-test`). `run-loop.sh` will not recycle a window until all required checkpoints are found in the pane output. If `verify-complete.sh` fails, the agent is re-briefed automatically.
|
||||
|
||||
## Worktree lifecycle
|
||||
|
||||
```text
|
||||
spare/N branch → spawn-agent.sh (--session-id UUID) → window + feat/branch + claude running
|
||||
↓
|
||||
CHECKPOINT:<step> (as steps complete)
|
||||
↓
|
||||
ORCHESTRATOR:DONE
|
||||
↓
|
||||
verify-complete.sh: checkpoints ✓ + 0 threads + CI green + no fresh CHANGES_REQUESTED
|
||||
↓
|
||||
state → "done", notify, window KEPT OPEN
|
||||
↓
|
||||
user/orchestrator explicitly requests recycle
|
||||
↓
|
||||
recycle-agent.sh → spare/N (free again)
|
||||
```
|
||||
|
||||
**Windows are never auto-killed.** The worktree stays on its branch, the session stays alive. The agent is done working but the window, git state, and Claude session are all preserved until you choose to recycle.
|
||||
|
||||
**To resume a done or crashed session:**
|
||||
```bash
|
||||
# Resume by stored session ID (preferred — exact session, full context)
|
||||
claude --resume SESSION_ID --permission-mode bypassPermissions
|
||||
|
||||
# Or resume most recent session in that worktree directory
|
||||
cd /path/to/worktree && claude --continue --permission-mode bypassPermissions
|
||||
```
|
||||
|
||||
**To manually recycle when ready:**
|
||||
```bash
|
||||
bash ~/.claude/orchestrator/scripts/recycle-agent.sh SESSION:WIN WORKTREE_PATH spare/N
|
||||
# Then update state:
|
||||
jq --arg w "SESSION:WIN" '.agents |= map(if .window == $w then .state = "recycled" else . end)' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
## State file (`~/.claude/orchestrator-state.json`)
|
||||
|
||||
Never committed to git. You maintain this file directly using `jq` + atomic writes (`.tmp` → `mv`).
|
||||
|
||||
```json
|
||||
{
|
||||
"active": true,
|
||||
"tmux_session": "autogpt1",
|
||||
"idle_threshold_seconds": 300,
|
||||
"loop_window": "autogpt1:5",
|
||||
"repo": "Significant-Gravitas/AutoGPT",
|
||||
"discord_webhook": "https://discord.com/api/webhooks/...",
|
||||
"last_poll_at": 0,
|
||||
"agents": [
|
||||
{
|
||||
"window": "autogpt1:3",
|
||||
"worktree": "AutoGPT6",
|
||||
"worktree_path": "/path/to/AutoGPT6",
|
||||
"spare_branch": "spare/6",
|
||||
"branch": "feat/my-feature",
|
||||
"objective": "Implement X and open a PR",
|
||||
"pr_number": "12345",
|
||||
"session_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"steps": ["pr-address", "pr-test"],
|
||||
"checkpoints": ["pr-address"],
|
||||
"state": "running",
|
||||
"last_output_hash": "",
|
||||
"last_seen_at": 0,
|
||||
"spawned_at": 0,
|
||||
"idle_since": 0,
|
||||
"revision_count": 0,
|
||||
"last_rebriefed_at": 0
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Top-level optional fields:
|
||||
- `repo` — GitHub `owner/repo` for CI/thread checks. Auto-derived from git remote if omitted.
|
||||
- `discord_webhook` — Discord webhook URL for completion notifications. Also reads `DISCORD_WEBHOOK_URL` env var.
|
||||
|
||||
Per-agent fields:
|
||||
- `session_id` — UUID passed to `claude --session-id` at spawn; use with `claude --resume UUID` to restore exact session context after a crash or window close.
|
||||
- `last_rebriefed_at` — Unix timestamp of last re-brief; enforces 5-min cooldown to prevent spam.
|
||||
|
||||
Agent states: `running` | `idle` | `stuck` | `waiting_approval` | `complete` | `done` | `escalated`
|
||||
|
||||
`done` means verified complete — window is still open, session still alive, worktree still on task branch. Not recycled yet.
|
||||
|
||||
## Serial /pr-test rule
|
||||
|
||||
`/pr-test` and `/pr-test --fix` run local Docker + integration tests that use shared ports, a shared database, and shared build caches. **Running two `/pr-test` jobs simultaneously will cause port conflicts and database corruption.**
|
||||
|
||||
**Rule: only one `/pr-test` runs at a time. The orchestrator serializes them.**
|
||||
|
||||
You (the orchestrating Claude) own the test queue:
|
||||
1. Agents do `pr-review` and `pr-address` in parallel — that's safe (they only push code and reply to GitHub).
|
||||
2. When a PR needs local testing, add it to your mental queue — don't give agents a `pr-test` step.
|
||||
3. Run `/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER --fix` yourself, sequentially.
|
||||
4. Feed results back to the relevant agent via `tmux send-keys`:
|
||||
```bash
|
||||
tmux send-keys -t SESSION:WIN "Local tests for PR #N: <paste failure output or 'all passed'>. Fix any failures and push, then output ORCHESTRATOR:DONE."
|
||||
sleep 0.3
|
||||
tmux send-keys -t SESSION:WIN Enter
|
||||
```
|
||||
5. Wait for CI to confirm green before marking the agent done.
|
||||
|
||||
If multiple PRs need testing at the same time, pick the one furthest along (fewest pending CI checks) and test it first. Only start the next test after the previous one completes.
|
||||
|
||||
## Session restore (tested and confirmed)
|
||||
|
||||
Agent sessions are saved to disk. To restore a closed or crashed session:
|
||||
|
||||
```bash
|
||||
# If session_id is in state (preferred):
|
||||
NEW_WIN=$(tmux new-window -t SESSION -n WORKTREE_NAME -P -F '#{window_index}')
|
||||
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --resume SESSION_ID --permission-mode bypassPermissions" Enter
|
||||
|
||||
# If no session_id (use --continue for most recent session in that directory):
|
||||
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --continue --permission-mode bypassPermissions" Enter
|
||||
```
|
||||
|
||||
`--continue` restores the full conversation history including all tool calls, file edits, and context. The agent resumes exactly where it left off. After restoring, update the window address in the state file:
|
||||
|
||||
```bash
|
||||
jq --arg old "SESSION:OLD_WIN" --arg new "SESSION:NEW_WIN" \
|
||||
'(.agents[] | select(.window == $old)).window = $new' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
## Intent → action mapping
|
||||
|
||||
Match the user's message to one of these intents:
|
||||
|
||||
| The user says something like… | What to do |
|
||||
|---|---|
|
||||
| "status", "what's running", "show agents" | Run `status.sh` + `capacity.sh`, show output |
|
||||
| "how many free", "capacity", "available worktrees" | Run `capacity.sh`, show output |
|
||||
| "start N agents on X, Y, Z" or "run these tasks: …" | See **Spawning agents** below |
|
||||
| "add task: …", "add one more agent for …" | See **Adding an agent** below |
|
||||
| "stop", "shut down", "pause the fleet" | See **Stopping** below |
|
||||
| "poll", "check now", "run a cycle" | Run `poll-cycle.sh`, process actions |
|
||||
| "recycle window X", "free up autogpt3" | Run `recycle-agent.sh` directly |
|
||||
|
||||
When the intent is ambiguous, show capacity first and ask what tasks to run.
|
||||
|
||||
## Spawning agents
|
||||
|
||||
### 1. Resolve tmux session
|
||||
|
||||
```bash
|
||||
tmux list-sessions -F "#{session_name}: #{session_windows} windows" 2>/dev/null
|
||||
```
|
||||
|
||||
Use an existing session. **Never create a tmux session from within Claude** — it becomes a child of Claude's process and dies when the session ends. If no session exists, tell the user to run `tmux new-session -d -s autogpt1` in their terminal first, then re-invoke `/orchestrate`.
|
||||
|
||||
### 2. Show available capacity
|
||||
|
||||
```bash
|
||||
bash $SKILLS_DIR/capacity.sh $(git rev-parse --show-toplevel)
|
||||
```
|
||||
|
||||
### 3. Collect tasks from the user
|
||||
|
||||
For each task, gather:
|
||||
- **objective** — what to do (e.g. "implement feature X and open a PR")
|
||||
- **branch name** — e.g. `feat/my-feature` (derive from objective if not given)
|
||||
- **pr_number** — GitHub PR number if working on an existing PR (for verification)
|
||||
- **steps** — required checkpoint names in order (e.g. `pr-address pr-test`) — derive from objective
|
||||
|
||||
Ask for `idle_threshold_seconds` only if the user mentions it (default: 300).
|
||||
|
||||
Never ask the user to specify a worktree — auto-assign from `find-spare.sh`.
|
||||
|
||||
### 4. Spawn one agent per task
|
||||
|
||||
```bash
|
||||
# Get ordered list of spare worktrees
|
||||
SPARE_LIST=$(bash $SKILLS_DIR/find-spare.sh $(git rev-parse --show-toplevel))
|
||||
|
||||
# For each task, take the next spare line:
|
||||
WORKTREE_PATH=$(echo "$SPARE_LINE" | awk '{print $1}')
|
||||
SPARE_BRANCH=$(echo "$SPARE_LINE" | awk '{print $2}')
|
||||
|
||||
# With PR number and required steps:
|
||||
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE" "$PR_NUMBER" "pr-address" "pr-test")
|
||||
|
||||
# Without PR (new work):
|
||||
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE")
|
||||
```
|
||||
|
||||
Build an agent record and append it to the state file. If the state file doesn't exist yet, initialize it:
|
||||
|
||||
```bash
|
||||
# Derive repo from git remote (used by verify-complete.sh + supervisor)
|
||||
REPO=$(git remote get-url origin 2>/dev/null | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
|
||||
|
||||
jq -n \
|
||||
--arg session "$SESSION" \
|
||||
--arg repo "$REPO" \
|
||||
--argjson threshold 300 \
|
||||
'{active:true, tmux_session:$session, idle_threshold_seconds:$threshold,
|
||||
repo:$repo, loop_window:null, supervisor_window:null, last_poll_at:0, agents:[]}' \
|
||||
> ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
Optionally add a Discord webhook for completion notifications:
|
||||
```bash
|
||||
jq --arg hook "$DISCORD_WEBHOOK_URL" '.discord_webhook = $hook' ~/.claude/orchestrator-state.json \
|
||||
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
`spawn-agent.sh` writes the initial agent record (window, worktree_path, branch, objective, state, etc.) to the state file automatically — **do not append the record again after calling it.** The record already exists and `pr_number`/`steps` are patched in by the script itself.
|
||||
|
||||
### 5. Start the mechanical babysitter
|
||||
|
||||
```bash
|
||||
LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
|
||||
LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
|
||||
tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter
|
||||
|
||||
jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
|
||||
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
### 6. Begin supervising directly in this conversation
|
||||
|
||||
You are the supervisor. After spawning, immediately start your first poll loop (see **Supervisor duties** below) and continue every 2-3 minutes. Do NOT spawn a separate supervisor Claude window.
|
||||
|
||||
## Adding an agent
|
||||
|
||||
Find the next spare worktree, then spawn and append to state — same as steps 2–4 above but for a single task. If no spare worktrees are available, tell the user.
|
||||
|
||||
## Supervisor duties (YOUR job, every 2-3 min in this conversation)
|
||||
|
||||
You are the supervisor. Run this poll loop directly in your Claude session — not in a separate window.
|
||||
|
||||
### Poll loop mechanism
|
||||
|
||||
You are reactive — you only act when a tool completes or the user sends a message. To create a self-sustaining poll loop without user involvement:
|
||||
|
||||
1. Start each poll with `run_in_background: true` + a sleep before the work:
|
||||
```bash
|
||||
sleep 120 && tmux capture-pane -t autogpt1:0 -p -S -200 | tail -40
|
||||
# + similar for each active window
|
||||
```
|
||||
2. When the background job notifies you, read the pane output and take action.
|
||||
3. Immediately schedule the next background poll — this keeps the loop alive.
|
||||
4. Stop scheduling when all agents are done/escalated.
|
||||
|
||||
**Never tell the user "I'll poll every 2-3 minutes"** — that does nothing without a trigger. Start the background job instead.
|
||||
|
||||
### Each poll: what to check
|
||||
|
||||
```bash
|
||||
# 1. Read state
|
||||
cat ~/.claude/orchestrator-state.json | jq '.agents[] | {window, worktree, branch, state, pr_number, checkpoints}'
|
||||
|
||||
# 2. For each running/stuck/idle agent, capture pane
|
||||
tmux capture-pane -t SESSION:WIN -p -S -200 | tail -60
|
||||
```
|
||||
|
||||
For each agent, decide:
|
||||
|
||||
| What you see | Action |
|
||||
|---|---|
|
||||
| Spinner / tools running | Do nothing — agent is working |
|
||||
| Idle `❯` prompt, no `ORCHESTRATOR:DONE` | Stalled — send specific nudge with objective from state |
|
||||
| Stuck in error loop | Send targeted fix with exact error + solution |
|
||||
| Waiting for input / question | Answer and unblock via `tmux send-keys` |
|
||||
| CI red | `gh pr checks PR_NUMBER --repo REPO` → tell agent exactly what's failing |
|
||||
| GitHub abuse rate limit error | Nudge: "Wait 60 seconds then continue posting replies with sleep 3 between each" |
|
||||
| Context compacted / agent lost | Send recovery: `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="WIN")'` + `gh pr view PR_NUMBER --json title,body` |
|
||||
| `ORCHESTRATOR:DONE` in output | Query GraphQL for actual unresolved count. If >0, re-brief. If 0, run `verify-complete.sh` |
|
||||
|
||||
**Poll all windows from state, not from memory.** Before each poll, run:
|
||||
```bash
|
||||
jq -r '.agents[] | select(.state | test("running|idle|stuck|waiting_approval|pending_evaluation")) | .window' ~/.claude/orchestrator-state.json
|
||||
```
|
||||
and capture every window listed. If you manually added a window outside spawn-agent.sh, ensure it's in the state file first.
|
||||
|
||||
### RUNNING count includes waiting_approval agents
|
||||
|
||||
The `RUNNING` count from run-loop.sh includes agents in `waiting_approval` state (they match the regex `running|stuck|waiting_approval|idle`). This means a fleet that is only `waiting_approval` still shows RUNNING > 0 in the log — it does **not** mean agents are actively working.
|
||||
|
||||
When you see `RUNNING > 0` in the run-loop log but suspect agents are actually blocked, check state directly:
|
||||
```bash
|
||||
jq '.agents[] | {window, state, worktree}' ~/.claude/orchestrator-state.json
|
||||
```
|
||||
A count of `running=1 waiting=1` in the log actually means one agent is waiting for approval — the orchestrator should check and approve, not wait.
|
||||
|
||||
### State file staleness recovery
|
||||
|
||||
The state file is written by scripts but can drift from reality when windows are closed, sessions expire, or the orchestrator restarts across conversations.
|
||||
|
||||
**Signs of stale state:**
|
||||
- `loop_window` points to a window that no longer exists in the tmux session
|
||||
- An agent's `state` is `running` but tmux window is closed or shows a shell prompt (not claude)
|
||||
- `last_seen_at` is hours old but state still says `running`
|
||||
|
||||
**Recovery steps:**
|
||||
|
||||
1. **Verify actual tmux windows:**
|
||||
```bash
|
||||
tmux list-windows -t SESSION -F '#{window_index}: #{window_name} (#{pane_current_command})'
|
||||
```
|
||||
|
||||
2. **Cross-reference with state file:**
|
||||
```bash
|
||||
jq -r '.agents[] | "\(.window) \(.state) \(.worktree)"' ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
3. **Fix stale entries:**
|
||||
```bash
|
||||
# Agent window closed — mark idle so run-loop.sh will restart it
|
||||
jq --arg w "SESSION:WIN" '(.agents[] | select(.window==$w)).state = "idle"' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
|
||||
# loop_window gone — kill the stale reference, then restart run-loop.sh
|
||||
jq '.loop_window = null' ~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
|
||||
LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
|
||||
tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter
|
||||
jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
|
||||
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
4. **After any state repair, re-run `status.sh` to confirm coherence before resuming supervision.**
|
||||
|
||||
### Strict ORCHESTRATOR:DONE gate
|
||||
|
||||
`verify-complete.sh` handles the main checks automatically (checkpoints, threads, CI green, spawned_at, and CHANGES_REQUESTED). Run it:
|
||||
|
||||
**CHANGES_REQUESTED staleness rule**: a `CHANGES_REQUESTED` review only blocks if it was submitted *after* the latest commit. If the latest commit postdates the review, the review is considered stale (feedback already addressed) and does not block. This avoids false negatives when a bot reviewer hasn't re-reviewed after the agent's fixing commits.
|
||||
|
||||
```bash
|
||||
SKILLS_DIR=~/.claude/orchestrator/scripts
|
||||
bash $SKILLS_DIR/verify-complete.sh SESSION:WIN
|
||||
```
|
||||
|
||||
If it passes → run-loop.sh will recycle the window automatically. No manual action needed.
|
||||
If it fails → re-brief the agent with the failure reason. Never manually mark state `done` to bypass this.
|
||||
|
||||
### Re-brief a stalled agent
|
||||
|
||||
**Before sending any nudge, verify the pane is at an idle ❯ prompt.** Sending text into a still-processing pane produces stuck `[Pasted text +N lines]` that the agent never sees.
|
||||
|
||||
Check:
|
||||
```bash
|
||||
tmux capture-pane -t SESSION:WIN -p 2>/dev/null | tail -5
|
||||
```
|
||||
If the last line shows a spinner (✳✽✢✶·), `Running…`, or no `❯` — wait 10–15s and check again before sending.
|
||||
|
||||
```bash
|
||||
OBJ=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .objective' ~/.claude/orchestrator-state.json)
|
||||
PR=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .pr_number' ~/.claude/orchestrator-state.json)
|
||||
tmux send-keys -t SESSION:WIN "You appear stalled. Your objective: $OBJ. Check: gh pr view $PR --json title,body,headRefName to reorient."
|
||||
sleep 0.3
|
||||
tmux send-keys -t SESSION:WIN Enter
|
||||
```
|
||||
|
||||
If `image_path` is set on the agent record, include: "Re-read context at IMAGE_PATH with the Read tool."
|
||||
|
||||
## Self-recovery protocol (agents)
|
||||
|
||||
spawn-agent.sh automatically includes this instruction in every objective:
|
||||
|
||||
> If your context compacts and you lose track of what to do, run:
|
||||
> `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="SESSION:WIN")'`
|
||||
> and `gh pr view PR_NUMBER --json title,body,headRefName` to reorient.
|
||||
> Output each completed step as `CHECKPOINT:<step-name>` on its own line.
|
||||
|
||||
## Passing images and screenshots to agents
|
||||
|
||||
`tmux send-keys` is text-only — you cannot paste a raw image into a pane. To give an agent visual context (screenshots, diagrams, mockups):
|
||||
|
||||
1. **Save the image to a temp file** with a stable path:
|
||||
```bash
|
||||
# If the user drags in a screenshot or you receive a file path:
|
||||
IMAGE_PATH="/tmp/orchestrator-context-$(date +%s).png"
|
||||
cp "$USER_PROVIDED_PATH" "$IMAGE_PATH"
|
||||
```
|
||||
|
||||
2. **Reference the path in the objective string**:
|
||||
```bash
|
||||
OBJECTIVE="Implement the layout shown in /tmp/orchestrator-context-1234567890.png. Read that image first with the Read tool to understand the design."
|
||||
```
|
||||
|
||||
3. The agent uses its `Read` tool to view the image at startup — Claude Code agents are multimodal and can read image files directly.
|
||||
|
||||
**Rule**: always use `/tmp/orchestrator-context-<timestamp>.png` as the naming convention so the supervisor knows what to look for if it needs to re-brief an agent with the same image.
|
||||
|
||||
---
|
||||
|
||||
## Orchestrator final evaluation (YOU decide, not the script)
|
||||
|
||||
`verify-complete.sh` is a gate — it blocks premature marking. But it cannot tell you if the work is actually good. That is YOUR job.
|
||||
|
||||
When run-loop marks an agent `pending_evaluation` and you're notified, do all of these before marking done:
|
||||
|
||||
### 1. Run /pr-test (required, serialized, use TodoWrite to queue)
|
||||
|
||||
`/pr-test` is the only reliable confirmation that the objective is actually met. Run it yourself, not the agent.
|
||||
|
||||
**When multiple PRs reach `pending_evaluation` at the same time, use TodoWrite to queue them:**
|
||||
```
|
||||
- [ ] /pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/NNNN — <feature description>
|
||||
- [ ] /pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/MMMM — <feature description>
|
||||
```
|
||||
Run one at a time. Check off as you go.
|
||||
|
||||
```
|
||||
/pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/PR_NUMBER
|
||||
```
|
||||
|
||||
**/pr-test can be lazy** — if it gives vague output, re-run with full context:
|
||||
|
||||
```
|
||||
/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER
|
||||
Context: This PR implements <objective from state file>. Key files: <list>.
|
||||
Please verify: <specific behaviors to check>.
|
||||
```
|
||||
|
||||
Only one `/pr-test` at a time — they share ports and DB.
|
||||
|
||||
### /pr-test result evaluation
|
||||
|
||||
**PARTIAL on any headline feature scenario is an immediate blocker.** Do not approve, do not mark done, do not let the agent output `ORCHESTRATOR:DONE`.
|
||||
|
||||
| `/pr-test` result | Action |
|
||||
|---|---|
|
||||
| All headline scenarios **PASS** | Proceed to evaluation step 2 |
|
||||
| Any headline scenario **PARTIAL** | Re-brief the agent immediately — see below |
|
||||
| Any headline scenario **FAIL** | Re-brief the agent immediately |
|
||||
|
||||
**What PARTIAL means**: the feature is only partly working. Example: the Apply button never appeared, or the AI returned no action blocks. The agent addressed part of the objective but not all of it.
|
||||
|
||||
**When any headline scenario is PARTIAL or FAIL:**
|
||||
|
||||
1. Do NOT mark the agent done or accept `ORCHESTRATOR:DONE`
|
||||
2. Re-brief the agent with the specific scenario that failed and what was missing:
|
||||
```bash
|
||||
tmux send-keys -t SESSION:WIN "PARTIAL result on /pr-test — S5 (Apply button) never appeared. The AI must output JSON action blocks for the Apply button to render. Fix this before re-running /pr-test."
|
||||
sleep 0.3
|
||||
tmux send-keys -t SESSION:WIN Enter
|
||||
```
|
||||
3. Set state back to `running`:
|
||||
```bash
|
||||
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "running"' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
4. Wait for new `ORCHESTRATOR:DONE`, then re-run `/pr-test` from scratch
|
||||
|
||||
**Rule: only ALL-PASS qualifies for approval.** A mix of PASS + PARTIAL is a failure.
|
||||
|
||||
> **Why this matters**: A PR was once wrongly approved with S5 PARTIAL — the AI never output JSON action blocks so the Apply button never appeared. The fix was already in the agent's reach but slipped through because PARTIAL was not treated as blocking.
|
||||
|
||||
### 2. Do your own evaluation
|
||||
|
||||
1. **Read the PR diff and objective** — does the code actually implement what was asked? Is anything obviously missing or half-done?
|
||||
2. **Read the resolved threads** — were comments addressed with real fixes, or just dismissed/resolved without changes?
|
||||
3. **Check CI run names** — any suspicious retries that shouldn't have passed?
|
||||
4. **Check the PR description** — title, summary, test plan complete?
|
||||
|
||||
### 3. Decide
|
||||
|
||||
- `/pr-test` all scenarios PASS + evaluation looks good → mark `done` in state, tell the user the PR is ready, ask if window should be closed
|
||||
- `/pr-test` any scenario PARTIAL or FAIL → re-brief the agent with the specific failing scenario, set state back to `running` (see `/pr-test result evaluation` above)
|
||||
- Evaluation finds gaps even with all PASS → re-brief the agent with specific gaps, set state back to `running`
|
||||
|
||||
**Never mark done based purely on script output.** You hold the full objective context; the script does not.
|
||||
|
||||
```bash
|
||||
# Mark done after your positive evaluation:
|
||||
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "done"' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
## When to stop the fleet
|
||||
|
||||
Stop the fleet (`active = false`) when **all** of the following are true:
|
||||
|
||||
| Check | How to verify |
|
||||
|---|---|
|
||||
| All agents are `done` or `escalated` | `jq '[.agents[] | select(.state | test("running\|stuck\|idle\|waiting_approval"))] | length' ~/.claude/orchestrator-state.json` == 0 |
|
||||
| All PRs have 0 unresolved review threads | GraphQL `isResolved` check per PR |
|
||||
| All PRs have green CI **on a run triggered after the agent's last push** | `gh run list --branch BRANCH --limit 1` timestamp > `spawned_at` in state |
|
||||
| No fresh CHANGES_REQUESTED (after latest commit) | `verify-complete.sh` checks this — stale pre-commit reviews are ignored |
|
||||
| No agents are `escalated` without human review | If any are escalated, surface to user first |
|
||||
|
||||
**Do NOT stop just because agents output `ORCHESTRATOR:DONE`.** That is a signal to verify, not a signal to stop.
|
||||
|
||||
**Do stop** if the user explicitly says "stop", "shut down", or "kill everything", even with agents still running.
|
||||
|
||||
```bash
|
||||
# Graceful stop
|
||||
jq '.active = false' ~/.claude/orchestrator-state.json > /tmp/orch.tmp \
|
||||
&& mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
|
||||
LOOP_WINDOW=$(jq -r '.loop_window // ""' ~/.claude/orchestrator-state.json)
|
||||
[ -n "$LOOP_WINDOW" ] && tmux kill-window -t "$LOOP_WINDOW" 2>/dev/null || true
|
||||
```
|
||||
|
||||
Does **not** recycle running worktrees — agents may still be mid-task. Run `capacity.sh` to see what's still in progress.
|
||||
|
||||
## tmux send-keys pattern
|
||||
|
||||
**Always split long messages into text + Enter as two separate calls with a sleep between them.** If sent as one call (`"text" Enter`), Enter can fire before the full string is buffered into Claude's input — leaving the message stuck as `[Pasted text +N lines]` unsent.
|
||||
|
||||
```bash
|
||||
# CORRECT — text then Enter separately
|
||||
tmux send-keys -t "$WINDOW" "your long message here"
|
||||
sleep 0.3
|
||||
tmux send-keys -t "$WINDOW" Enter
|
||||
|
||||
# WRONG — Enter may fire before text is buffered
|
||||
tmux send-keys -t "$WINDOW" "your long message here" Enter
|
||||
```
|
||||
|
||||
Short single-character sends (`y`, `Down`, empty Enter for dialog approval) are safe to combine since they have no buffering lag.
|
||||
|
||||
---
|
||||
|
||||
## Protected worktrees
|
||||
|
||||
Some worktrees must **never** be used as spare worktrees for agent tasks because they host files critical to the orchestrator itself:
|
||||
|
||||
| Worktree | Protected branch | Why |
|
||||
|---|---|---|
|
||||
| `AutoGPT1` | `dx/orchestrate-skill` | Hosts the orchestrate skill scripts. `recycle-agent.sh` would check out `spare/1`, wiping `.claude/skills/` and breaking all subsequent `spawn-agent.sh` calls. |
|
||||
|
||||
**Rule**: when selecting spare worktrees via `find-spare.sh`, skip any worktree whose CURRENT branch matches a protected branch. If you accidentally spawn an agent in a protected worktree, do not let `recycle-agent.sh` run on it — manually restore the branch after the agent finishes.
|
||||
|
||||
When `dx/orchestrate-skill` is merged into `dev`, `AutoGPT1` becomes a normal spare again.
|
||||
|
||||
---
|
||||
|
||||
## Thread resolution integrity (critical)
|
||||
|
||||
**Agents MUST NOT resolve review threads via GraphQL unless a real code fix has been committed and pushed first.**
|
||||
|
||||
This is the most common failure mode: agents call `resolveReviewThread` to make unresolved counts drop without actually fixing anything. This produces a false "done" signal that gets past verify-complete.sh.
|
||||
|
||||
**The only valid resolution sequence:**
|
||||
1. Read the thread and understand what it's asking
|
||||
2. Make the actual code change
|
||||
3. `git commit` and `git push`
|
||||
4. Reply to the thread with the commit SHA (e.g. "Fixed in `abc1234`")
|
||||
5. THEN call `resolveReviewThread`
|
||||
|
||||
**The supervisor must verify actual thread counts via GraphQL** — never trust an agent's claim of "0 unresolved." After any agent's ORCHESTRATOR:DONE, always run:
|
||||
|
||||
```bash
|
||||
# Step 1: get total count
|
||||
TOTAL=$(gh api graphql -f query='{ repository(owner: "OWNER", name: "REPO") { pullRequest(number: PR) { reviewThreads { totalCount } } } }' \
|
||||
| jq '.data.repository.pullRequest.reviewThreads.totalCount')
|
||||
echo "Total threads: $TOTAL"
|
||||
|
||||
# Step 2: paginate all pages and count unresolved
|
||||
CURSOR=""; UNRESOLVED=0
|
||||
while true; do
|
||||
AFTER=${CURSOR:+", after: \"$CURSOR\""}
|
||||
PAGE=$(gh api graphql -f query="{ repository(owner: \"OWNER\", name: \"REPO\") { pullRequest(number: PR) { reviewThreads(first: 100${AFTER}) { pageInfo { hasNextPage endCursor } nodes { isResolved } } } } }")
|
||||
UNRESOLVED=$(( UNRESOLVED + $(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved==false)] | length') ))
|
||||
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
|
||||
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
|
||||
[ "$HAS_NEXT" = "false" ] && break
|
||||
done
|
||||
echo "Unresolved: $UNRESOLVED"
|
||||
```
|
||||
|
||||
If unresolved > 0, the agent is NOT done — re-brief with the actual count and the rule.
|
||||
|
||||
**Include this in every agent objective:**
|
||||
> IMPORTANT: Do NOT resolve any review thread via GraphQL unless the code fix is committed and pushed first. Fix the code → commit → push → reply with SHA → then resolve. Never resolve without a real commit. "Accepted" or "Acknowledged" replies are NOT resolutions — only real commits qualify.
|
||||
|
||||
### Detecting fake resolutions
|
||||
|
||||
When an agent claims "0 unresolved threads", query GitHub GraphQL yourself and also inspect how each thread was resolved. A resolved thread whose last comment is `"Acknowledged"`, `"Same as above"`, `"Accepted trade-off"`, or `"Deferred"` — with no commit SHA — is a fake resolution.
|
||||
|
||||
To spot these, paginate all pages and collect resolved threads with missing SHA links:
|
||||
```bash
|
||||
# Paginate all pages — first:100 misses threads beyond page 1 on large PRs
|
||||
CURSOR=""; FAKE_RESOLUTIONS="[]"
|
||||
while true; do
|
||||
AFTER=${CURSOR:+", after: \"$CURSOR\""}
|
||||
PAGE=$(gh api graphql -f query="
|
||||
{
|
||||
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
|
||||
pullRequest(number: PR_NUMBER) {
|
||||
reviewThreads(first: 100${AFTER}) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes {
|
||||
isResolved
|
||||
comments(last: 1) {
|
||||
nodes { body author { login } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}")
|
||||
PAGE_FAKES=$(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[]
|
||||
| select(.isResolved == true)
|
||||
| {body: .comments.nodes[0].body[:120], author: .comments.nodes[0].author.login}
|
||||
| select(.body | test("Fixed in|Removed in|Addressed in") | not)]')
|
||||
FAKE_RESOLUTIONS=$(echo "$FAKE_RESOLUTIONS $PAGE_FAKES" | jq -s 'add')
|
||||
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
|
||||
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
|
||||
[ "$HAS_NEXT" = "false" ] && break
|
||||
done
|
||||
echo "$FAKE_RESOLUTIONS"
|
||||
```
|
||||
Any resolved thread whose last comment does NOT contain `"Fixed in"`, `"Removed in"`, or `"Addressed in"` (with a commit link) should be investigated — either the agent falsely resolved it, or it was a genuine false positive that needs explanation.
|
||||
|
||||
## GitHub abuse rate limits
|
||||
|
||||
Two distinct rate limits exist with different recovery times:
|
||||
|
||||
| Error | HTTP status | Cause | Recovery |
|
||||
|---|---|---|---|
|
||||
| `{"code":"abuse"}` in body | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **2–3 minutes**. 60s is often not enough. |
|
||||
| `API rate limit exceeded` | 429 | Primary rate limit — too many read calls per hour | Wait until `X-RateLimit-Reset` timestamp |
|
||||
|
||||
**Prevention:** Agents must add `sleep 3` between individual thread reply API calls. For >20 unresolved threads, increase to `sleep 5`.
|
||||
|
||||
If you see a 403 `abuse` error from an agent's pane:
|
||||
1. Nudge the agent: `"You hit a GitHub secondary rate limit (403). Stop all API writes. Wait 2 minutes, then resume with sleep 3 between each thread reply."`
|
||||
2. Do NOT nudge again during the 2-minute wait — a second nudge restarts the clock.
|
||||
|
||||
Add this to agent briefings when there are >20 unresolved threads:
|
||||
> Post replies with `sleep 3` between each reply. If you hit a 403 abuse error, wait 2 minutes (not 60s — secondary limits take longer to clear) then continue.
|
||||
|
||||
## Key rules
|
||||
|
||||
1. **Scripts do all the heavy lifting** — don't reimplement their logic inline in this file
|
||||
2. **Never ask the user to pick a worktree** — auto-assign from `find-spare.sh` output
|
||||
3. **Never restart a running agent** — only restart on `idle` kicks (foreground is a shell)
|
||||
4. **Auto-dismiss settings dialogs** — if "Enter to confirm" appears, send Down+Enter
|
||||
5. **Always `--permission-mode bypassPermissions`** on every spawn
|
||||
6. **Escalate after 3 kicks** — mark `escalated`, surface to user
|
||||
7. **Atomic state writes** — always write to `.tmp` then `mv`
|
||||
8. **Never approve destructive commands** outside the worktree scope — when in doubt, escalate
|
||||
9. **Never recycle without verification** — `verify-complete.sh` must pass before recycling
|
||||
10. **No TASK.md files** — commit risk; use state file + `gh pr view` for agent context persistence
|
||||
11. **Re-brief stalled agents** — read objective from state file + `gh pr view`, send via tmux
|
||||
12. **ORCHESTRATOR:DONE is a signal to verify, not to accept** — always run `verify-complete.sh` and check CI run timestamp before recycling
|
||||
13. **Protected worktrees** — never use the worktree hosting the skill scripts as a spare
|
||||
14. **Images via file path** — save screenshots to `/tmp/orchestrator-context-<ts>.png`, pass path in objective; agents read with the `Read` tool
|
||||
15. **Split send-keys** — always separate text and Enter with `sleep 0.3` between calls for long strings
|
||||
16. **Poll ALL windows from state file** — never hardcode window count. Derive active windows dynamically: `jq -r '.agents[] | select(.state | test("running|idle|stuck")) | .window' ~/.claude/orchestrator-state.json`. If you added a window mid-session outside spawn-agent.sh, add it to the state file immediately.
|
||||
20. **Orchestrator handles its own approvals** — when spawning a subagent to make edits (SKILL.md, scripts, config), review the diff yourself and approve/reject without surfacing it to the user. The user should never have to open a file to check the orchestrator's work. Use the Agent tool with `subagent_type: general-purpose` for drafting, then verify the result yourself before considering the task done.
|
||||
17. **Update state file on re-task** — whenever an agent is re-tasked mid-session (objective changes, new PR assigned), update the state file record immediately so objectives stay accurate for re-briefing after compaction.
|
||||
18. **No GraphQL resolveReviewThread without a commit** — see Thread resolution integrity above. This is rule #1 for pr-address work.
|
||||
19. **Verify thread counts yourself** — after any agent claims "0 unresolved threads", query GitHub GraphQL directly before accepting. Never trust the agent's self-report.
|
||||
@@ -1,43 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# capacity.sh — show fleet capacity: available spare worktrees + in-use agents
|
||||
#
|
||||
# Usage: capacity.sh [REPO_ROOT]
|
||||
# REPO_ROOT defaults to the root worktree of the current git repo.
|
||||
#
|
||||
# Reads: ~/.claude/orchestrator-state.json (skipped if missing or corrupt)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
|
||||
|
||||
echo "=== Available (spare) worktrees ==="
|
||||
if [ -n "$REPO_ROOT" ]; then
|
||||
SPARE=$("$SCRIPTS_DIR/find-spare.sh" "$REPO_ROOT" 2>/dev/null || echo "")
|
||||
else
|
||||
SPARE=$("$SCRIPTS_DIR/find-spare.sh" 2>/dev/null || echo "")
|
||||
fi
|
||||
|
||||
if [ -z "$SPARE" ]; then
|
||||
echo " (none)"
|
||||
else
|
||||
while IFS= read -r line; do
|
||||
[ -z "$line" ] && continue
|
||||
echo " ✓ $line"
|
||||
done <<< "$SPARE"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== In-use worktrees ==="
|
||||
if [ -f "$STATE_FILE" ] && jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
IN_USE=$(jq -r '.agents[] | select(.state != "done") | " [\(.state)] \(.worktree_path) → \(.branch)"' \
|
||||
"$STATE_FILE" 2>/dev/null || echo "")
|
||||
if [ -n "$IN_USE" ]; then
|
||||
echo "$IN_USE"
|
||||
else
|
||||
echo " (none)"
|
||||
fi
|
||||
else
|
||||
echo " (no active state file)"
|
||||
fi
|
||||
@@ -1,85 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# classify-pane.sh — Classify the current state of a tmux pane
|
||||
#
|
||||
# Usage: classify-pane.sh <tmux-target>
|
||||
# tmux-target: e.g. "work:0", "work:1.0"
|
||||
#
|
||||
# Output (stdout): JSON object:
|
||||
# { "state": "running|idle|waiting_approval|complete", "reason": "...", "pane_cmd": "..." }
|
||||
#
|
||||
# Exit codes: 0=ok, 1=error (invalid target or tmux window not found)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
TARGET="${1:-}"
|
||||
|
||||
if [ -z "$TARGET" ]; then
|
||||
echo '{"state":"error","reason":"no target provided","pane_cmd":""}'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate tmux target format: session:window or session:window.pane
|
||||
if ! [[ "$TARGET" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
|
||||
echo '{"state":"error","reason":"invalid tmux target format","pane_cmd":""}'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check session exists (use %%:* to extract session name from session:window)
|
||||
if ! tmux list-windows -t "${TARGET%%:*}" &>/dev/null 2>&1; then
|
||||
echo '{"state":"error","reason":"tmux target not found","pane_cmd":""}'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get the current foreground command in the pane
|
||||
PANE_CMD=$(tmux display-message -t "$TARGET" -p '#{pane_current_command}' 2>/dev/null || echo "unknown")
|
||||
|
||||
# Capture and strip ANSI codes (use perl for cross-platform compatibility — BSD sed lacks \x1b support)
|
||||
RAW=$(tmux capture-pane -t "$TARGET" -p -S -50 2>/dev/null || echo "")
|
||||
CLEAN=$(echo "$RAW" | perl -pe 's/\x1b\[[0-9;]*[a-zA-Z]//g; s/\x1b\(B//g; s/\x1b\[\?[0-9]*[hl]//g; s/\r//g' \
|
||||
| grep -v '^[[:space:]]*$' || true)
|
||||
|
||||
# --- Check: explicit completion marker ---
|
||||
# Must be on its own line (not buried in the objective text sent at spawn time).
|
||||
if echo "$CLEAN" | grep -qE "^[[:space:]]*ORCHESTRATOR:DONE[[:space:]]*$"; then
|
||||
jq -n --arg cmd "$PANE_CMD" '{"state":"complete","reason":"ORCHESTRATOR:DONE marker found","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# --- Check: Claude Code approval prompt patterns ---
|
||||
LAST_40=$(echo "$CLEAN" | tail -40)
|
||||
APPROVAL_PATTERNS=(
|
||||
"Do you want to proceed"
|
||||
"Do you want to make this"
|
||||
"\\[y/n\\]"
|
||||
"\\[Y/n\\]"
|
||||
"\\[n/Y\\]"
|
||||
"Proceed\\?"
|
||||
"Allow this command"
|
||||
"Run bash command"
|
||||
"Allow bash"
|
||||
"Would you like"
|
||||
"Press enter to continue"
|
||||
"Esc to cancel"
|
||||
)
|
||||
for pattern in "${APPROVAL_PATTERNS[@]}"; do
|
||||
if echo "$LAST_40" | grep -qiE "$pattern"; then
|
||||
jq -n --arg pattern "$pattern" --arg cmd "$PANE_CMD" \
|
||||
'{"state":"waiting_approval","reason":"approval pattern: \($pattern)","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
fi
|
||||
done
|
||||
|
||||
# --- Check: shell prompt (claude has exited) ---
|
||||
# If the foreground process is a shell (not claude/node), the agent has exited
|
||||
case "$PANE_CMD" in
|
||||
zsh|bash|fish|sh|dash|tcsh|ksh)
|
||||
jq -n --arg cmd "$PANE_CMD" \
|
||||
'{"state":"idle","reason":"agent exited — shell prompt active","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
;;
|
||||
esac
|
||||
|
||||
# Agent is still running (claude/node/python is the foreground process)
|
||||
jq -n --arg cmd "$PANE_CMD" \
|
||||
'{"state":"running","reason":"foreground process: \($cmd)","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
@@ -1,24 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# find-spare.sh — list worktrees on spare/N branches (free to use)
|
||||
#
|
||||
# Usage: find-spare.sh [REPO_ROOT]
|
||||
# REPO_ROOT defaults to the root worktree containing the current git repo.
|
||||
#
|
||||
# Output (stdout): one line per available worktree: "PATH BRANCH"
|
||||
# e.g.: /Users/me/Code/AutoGPT3 spare/3
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
|
||||
if [ -z "$REPO_ROOT" ]; then
|
||||
echo "Error: not inside a git repo and no REPO_ROOT provided" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
git -C "$REPO_ROOT" worktree list --porcelain \
|
||||
| awk '
|
||||
/^worktree / { path = substr($0, 10) }
|
||||
/^branch / { branch = substr($0, 8); print path " " branch }
|
||||
' \
|
||||
| { grep -E " refs/heads/spare/[0-9]+$" || true; } \
|
||||
| sed 's|refs/heads/||'
|
||||
@@ -1,40 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# notify.sh — send a fleet notification message
|
||||
#
|
||||
# Delivery order (first available wins):
|
||||
# 1. Discord webhook — DISCORD_WEBHOOK_URL env var OR state file .discord_webhook
|
||||
# 2. macOS notification center — osascript (silent fail if unavailable)
|
||||
# 3. Stdout only
|
||||
#
|
||||
# Usage: notify.sh MESSAGE
|
||||
# Exit: always 0 (notification failure must not abort the caller)
|
||||
|
||||
MESSAGE="${1:-}"
|
||||
[ -z "$MESSAGE" ] && exit 0
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
# --- Resolve Discord webhook ---
|
||||
WEBHOOK="${DISCORD_WEBHOOK_URL:-}"
|
||||
if [ -z "$WEBHOOK" ] && [ -f "$STATE_FILE" ]; then
|
||||
WEBHOOK=$(jq -r '.discord_webhook // ""' "$STATE_FILE" 2>/dev/null || echo "")
|
||||
fi
|
||||
|
||||
# --- Discord delivery ---
|
||||
if [ -n "$WEBHOOK" ]; then
|
||||
PAYLOAD=$(jq -n --arg msg "$MESSAGE" '{"content": $msg}')
|
||||
curl -s -X POST "$WEBHOOK" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "$PAYLOAD" > /dev/null 2>&1 || true
|
||||
fi
|
||||
|
||||
# --- macOS notification center (silent if not macOS or osascript missing) ---
|
||||
if command -v osascript &>/dev/null 2>&1; then
|
||||
# Escape single quotes for AppleScript
|
||||
SAFE_MSG=$(echo "$MESSAGE" | sed "s/'/\\\\'/g")
|
||||
osascript -e "display notification \"${SAFE_MSG}\" with title \"Orchestrator\"" 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# Always print to stdout so run-loop.sh logs it
|
||||
echo "$MESSAGE"
|
||||
exit 0
|
||||
@@ -1,257 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# poll-cycle.sh — Single orchestrator poll cycle
|
||||
#
|
||||
# Reads ~/.claude/orchestrator-state.json, classifies each agent, updates state,
|
||||
# and outputs a JSON array of actions for Claude to take.
|
||||
#
|
||||
# Usage: poll-cycle.sh
|
||||
# Output (stdout): JSON array of action objects
|
||||
# [{ "window": "work:0", "action": "kick|approve|none", "state": "...",
|
||||
# "worktree": "...", "objective": "...", "reason": "..." }]
|
||||
#
|
||||
# The state file is updated in-place (atomic write via .tmp).
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
CLASSIFY="$SCRIPTS_DIR/classify-pane.sh"
|
||||
|
||||
# Cross-platform md5: always outputs just the hex digest
|
||||
md5_hash() {
|
||||
if command -v md5sum &>/dev/null; then
|
||||
md5sum | awk '{print $1}'
|
||||
else
|
||||
md5 | awk '{print $NF}'
|
||||
fi
|
||||
}
|
||||
|
||||
# Clean up temp file on any exit (avoids stale .tmp if jq write fails)
|
||||
trap 'rm -f "${STATE_FILE}.tmp"' EXIT
|
||||
|
||||
# Ensure state file exists
|
||||
if [ ! -f "$STATE_FILE" ]; then
|
||||
echo '{"active":false,"agents":[]}' > "$STATE_FILE"
|
||||
fi
|
||||
|
||||
# Validate JSON upfront before any jq reads that run under set -e.
|
||||
# A truncated/corrupt file (e.g. from a SIGKILL mid-write) would otherwise
|
||||
# abort the script at the ACTIVE read below without emitting any JSON output.
|
||||
if ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
echo "State file parse error — check $STATE_FILE" >&2
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
ACTIVE=$(jq -r '.active // false' "$STATE_FILE")
|
||||
if [ "$ACTIVE" != "true" ]; then
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
NOW=$(date +%s)
|
||||
IDLE_THRESHOLD=$(jq -r '.idle_threshold_seconds // 300' "$STATE_FILE")
|
||||
|
||||
ACTIONS="[]"
|
||||
UPDATED_AGENTS="[]"
|
||||
|
||||
# Read agents as newline-delimited JSON objects.
|
||||
# jq exits non-zero when .agents[] has no matches on an empty array, which is valid —
|
||||
# so we suppress that exit code and separately validate the file is well-formed JSON.
|
||||
if ! AGENTS_JSON=$(jq -e -c '.agents // empty | .[]' "$STATE_FILE" 2>/dev/null); then
|
||||
if ! jq -e '.' "$STATE_FILE" > /dev/null 2>&1; then
|
||||
echo "State file parse error — check $STATE_FILE" >&2
|
||||
fi
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ -z "$AGENTS_JSON" ]; then
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
while IFS= read -r agent; do
|
||||
[ -z "$agent" ] && continue
|
||||
|
||||
# Use // "" defaults so a single malformed field doesn't abort the whole cycle
|
||||
WINDOW=$(echo "$agent" | jq -r '.window // ""')
|
||||
WORKTREE=$(echo "$agent" | jq -r '.worktree // ""')
|
||||
OBJECTIVE=$(echo "$agent"| jq -r '.objective // ""')
|
||||
STATE=$(echo "$agent" | jq -r '.state // "running"')
|
||||
LAST_HASH=$(echo "$agent"| jq -r '.last_output_hash // ""')
|
||||
IDLE_SINCE=$(echo "$agent"| jq -r '.idle_since // 0')
|
||||
REVISION_COUNT=$(echo "$agent"| jq -r '.revision_count // 0')
|
||||
|
||||
# Validate window format to prevent tmux target injection.
|
||||
# Allow session:window (numeric or named) and session:window.pane
|
||||
if ! [[ "$WINDOW" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
|
||||
echo "Skipping agent with invalid window value: $WINDOW" >&2
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
|
||||
continue
|
||||
fi
|
||||
|
||||
# Pass-through terminal-state agents
|
||||
if [[ "$STATE" == "done" || "$STATE" == "escalated" || "$STATE" == "complete" || "$STATE" == "pending_evaluation" ]]; then
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
|
||||
continue
|
||||
fi
|
||||
|
||||
# Classify pane.
|
||||
# classify-pane.sh always emits JSON before exit (even on error), so using
|
||||
# "|| echo '...'" would concatenate two JSON objects when it exits non-zero.
|
||||
# Use "|| true" inside the substitution so set -euo pipefail does not abort
|
||||
# the poll cycle when classify exits with a non-zero status code.
|
||||
CLASSIFICATION=$("$CLASSIFY" "$WINDOW" 2>/dev/null || true)
|
||||
[ -z "$CLASSIFICATION" ] && CLASSIFICATION='{"state":"error","reason":"classify failed","pane_cmd":"unknown"}'
|
||||
|
||||
PANE_STATE=$(echo "$CLASSIFICATION" | jq -r '.state')
|
||||
PANE_REASON=$(echo "$CLASSIFICATION" | jq -r '.reason')
|
||||
|
||||
# Capture full pane output once — used for hash (stuck detection) and checkpoint parsing.
|
||||
# Use -S -500 to get the last ~500 lines of scrollback so checkpoints aren't missed.
|
||||
RAW=$(tmux capture-pane -t "$WINDOW" -p -S -500 2>/dev/null || echo "")
|
||||
|
||||
# --- Checkpoint tracking ---
|
||||
# Parse any "CHECKPOINT:<step>" lines the agent has output and merge into state file.
|
||||
# The agent writes these as it completes each required step so verify-complete.sh can gate recycling.
|
||||
EXISTING_CPS=$(echo "$agent" | jq -c '.checkpoints // []')
|
||||
NEW_CHECKPOINTS_JSON="$EXISTING_CPS"
|
||||
if [ -n "$RAW" ]; then
|
||||
FOUND_CPS=$(echo "$RAW" \
|
||||
| grep -oE "CHECKPOINT:[a-zA-Z0-9_-]+" \
|
||||
| sed 's/CHECKPOINT://' \
|
||||
| sort -u \
|
||||
| jq -R . | jq -s . 2>/dev/null || echo "[]")
|
||||
NEW_CHECKPOINTS_JSON=$(jq -n \
|
||||
--argjson existing "$EXISTING_CPS" \
|
||||
--argjson found "$FOUND_CPS" \
|
||||
'($existing + $found) | unique' 2>/dev/null || echo "$EXISTING_CPS")
|
||||
fi
|
||||
|
||||
# Compute content hash for stuck-detection (only for running agents)
|
||||
CURRENT_HASH=""
|
||||
if [[ "$PANE_STATE" == "running" ]] && [ -n "$RAW" ]; then
|
||||
CURRENT_HASH=$(echo "$RAW" | tail -20 | md5_hash)
|
||||
fi
|
||||
|
||||
NEW_STATE="$STATE"
|
||||
NEW_IDLE_SINCE="$IDLE_SINCE"
|
||||
NEW_REVISION_COUNT="$REVISION_COUNT"
|
||||
ACTION="none"
|
||||
REASON="$PANE_REASON"
|
||||
|
||||
case "$PANE_STATE" in
|
||||
complete)
|
||||
# Agent output ORCHESTRATOR:DONE — mark pending_evaluation so orchestrator handles it.
|
||||
# run-loop does NOT verify or notify; orchestrator's background poll picks this up.
|
||||
NEW_STATE="pending_evaluation"
|
||||
ACTION="complete" # run-loop logs it but takes no action
|
||||
;;
|
||||
waiting_approval)
|
||||
NEW_STATE="waiting_approval"
|
||||
ACTION="approve"
|
||||
;;
|
||||
idle)
|
||||
# Agent process has exited — needs restart
|
||||
NEW_STATE="idle"
|
||||
ACTION="kick"
|
||||
REASON="agent exited (shell is foreground)"
|
||||
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
|
||||
NEW_IDLE_SINCE=$NOW
|
||||
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
|
||||
NEW_STATE="escalated"
|
||||
ACTION="none"
|
||||
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
|
||||
fi
|
||||
;;
|
||||
running)
|
||||
# Clear idle_since only when transitioning from idle (agent was kicked and
|
||||
# restarted). Do NOT reset for stuck — idle_since must persist across polls
|
||||
# so STUCK_DURATION can accumulate and trigger escalation.
|
||||
# Also update the local IDLE_SINCE so the hash-stability check below uses
|
||||
# the reset value on this same poll, not the stale kick timestamp.
|
||||
if [[ "$STATE" == "idle" ]]; then
|
||||
NEW_IDLE_SINCE=0
|
||||
IDLE_SINCE=0
|
||||
fi
|
||||
# Check if hash has been stable (agent may be stuck mid-task)
|
||||
if [ -n "$CURRENT_HASH" ] && [ "$CURRENT_HASH" = "$LAST_HASH" ] && [ "$LAST_HASH" != "" ]; then
|
||||
if [ "$IDLE_SINCE" = "0" ] || [ "$IDLE_SINCE" = "null" ]; then
|
||||
NEW_IDLE_SINCE=$NOW
|
||||
else
|
||||
STUCK_DURATION=$(( NOW - IDLE_SINCE ))
|
||||
if [ "$STUCK_DURATION" -gt "$IDLE_THRESHOLD" ]; then
|
||||
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
|
||||
NEW_IDLE_SINCE=$NOW
|
||||
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
|
||||
NEW_STATE="escalated"
|
||||
ACTION="none"
|
||||
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
|
||||
else
|
||||
NEW_STATE="stuck"
|
||||
ACTION="kick"
|
||||
REASON="output unchanged for ${STUCK_DURATION}s (threshold: ${IDLE_THRESHOLD}s)"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
else
|
||||
# Only reset the idle timer when we have a valid hash comparison (pane
|
||||
# capture succeeded). If CURRENT_HASH is empty (tmux capture-pane failed),
|
||||
# preserve existing timers so stuck detection is not inadvertently reset.
|
||||
if [ -n "$CURRENT_HASH" ]; then
|
||||
NEW_STATE="running"
|
||||
NEW_IDLE_SINCE=0
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
error)
|
||||
REASON="classify error: $PANE_REASON"
|
||||
;;
|
||||
esac
|
||||
|
||||
# Build updated agent record (ensure idle_since and revision_count are numeric)
|
||||
# Use || true on each jq call so a malformed field skips this agent rather than
|
||||
# aborting the entire poll cycle under set -e.
|
||||
UPDATED_AGENT=$(echo "$agent" | jq \
|
||||
--arg state "$NEW_STATE" \
|
||||
--arg hash "$CURRENT_HASH" \
|
||||
--argjson now "$NOW" \
|
||||
--arg idle_since "$NEW_IDLE_SINCE" \
|
||||
--arg revision_count "$NEW_REVISION_COUNT" \
|
||||
--argjson checkpoints "$NEW_CHECKPOINTS_JSON" \
|
||||
'.state = $state
|
||||
| .last_output_hash = (if $hash == "" then .last_output_hash else $hash end)
|
||||
| .last_seen_at = $now
|
||||
| .idle_since = ($idle_since | tonumber)
|
||||
| .revision_count = ($revision_count | tonumber)
|
||||
| .checkpoints = $checkpoints' 2>/dev/null) || {
|
||||
echo "Warning: failed to build updated agent for window $WINDOW — keeping original" >&2
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
|
||||
continue
|
||||
}
|
||||
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$UPDATED_AGENT" '. + [$a]')
|
||||
|
||||
# Add action if needed
|
||||
if [ "$ACTION" != "none" ]; then
|
||||
ACTION_OBJ=$(jq -n \
|
||||
--arg window "$WINDOW" \
|
||||
--arg action "$ACTION" \
|
||||
--arg state "$NEW_STATE" \
|
||||
--arg worktree "$WORKTREE" \
|
||||
--arg objective "$OBJECTIVE" \
|
||||
--arg reason "$REASON" \
|
||||
'{window:$window, action:$action, state:$state, worktree:$worktree, objective:$objective, reason:$reason}')
|
||||
ACTIONS=$(echo "$ACTIONS" | jq --argjson a "$ACTION_OBJ" '. + [$a]')
|
||||
fi
|
||||
|
||||
done <<< "$AGENTS_JSON"
|
||||
|
||||
# Atomic state file update
|
||||
jq --argjson agents "$UPDATED_AGENTS" \
|
||||
--argjson now "$NOW" \
|
||||
'.agents = $agents | .last_poll_at = $now' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
|
||||
echo "$ACTIONS"
|
||||
@@ -1,32 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# recycle-agent.sh — kill a tmux window and restore the worktree to its spare branch
|
||||
#
|
||||
# Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH
|
||||
# WINDOW — tmux target, e.g. autogpt1:3
|
||||
# WORKTREE_PATH — absolute path to the git worktree
|
||||
# SPARE_BRANCH — branch to restore, e.g. spare/6
|
||||
#
|
||||
# Stdout: one status line
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
WINDOW="$1"
|
||||
WORKTREE_PATH="$2"
|
||||
SPARE_BRANCH="$3"
|
||||
|
||||
# Kill the tmux window (ignore error — may already be gone)
|
||||
tmux kill-window -t "$WINDOW" 2>/dev/null || true
|
||||
|
||||
# Restore to spare branch: abort any in-progress operation, then clean
|
||||
git -C "$WORKTREE_PATH" rebase --abort 2>/dev/null || true
|
||||
git -C "$WORKTREE_PATH" merge --abort 2>/dev/null || true
|
||||
git -C "$WORKTREE_PATH" reset --hard HEAD 2>/dev/null
|
||||
git -C "$WORKTREE_PATH" clean -fd 2>/dev/null
|
||||
git -C "$WORKTREE_PATH" checkout "$SPARE_BRANCH"
|
||||
|
||||
echo "Recycled: $(basename "$WORKTREE_PATH") → $SPARE_BRANCH (window $WINDOW closed)"
|
||||
@@ -1,215 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# run-loop.sh — Mechanical babysitter for the agent fleet (runs in its own tmux window)
|
||||
#
|
||||
# Handles ONLY two things that need no intelligence:
|
||||
# idle → restart claude using --resume SESSION_ID (or --continue) to restore context
|
||||
# approve → auto-approve safe dialogs, press Enter on numbered-option dialogs
|
||||
#
|
||||
# Everything else — ORCHESTRATOR:DONE, verification, /pr-test, final evaluation,
|
||||
# marking done, deciding to close windows — is the orchestrating Claude's job.
|
||||
# poll-cycle.sh sets state to pending_evaluation when ORCHESTRATOR:DONE is detected;
|
||||
# the orchestrator's background poll loop handles it from there.
|
||||
#
|
||||
# Usage: run-loop.sh
|
||||
# Env: POLL_INTERVAL (default: 30), ORCHESTRATOR_STATE_FILE
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Copy scripts to a stable location outside the repo so they survive branch
|
||||
# checkouts (e.g. recycle-agent.sh switching spare/N back into this worktree
|
||||
# would wipe .claude/skills/orchestrate/scripts if the skill only exists on the
|
||||
# current branch).
|
||||
_ORIGIN_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
STABLE_SCRIPTS_DIR="$HOME/.claude/orchestrator/scripts"
|
||||
mkdir -p "$STABLE_SCRIPTS_DIR"
|
||||
cp "$_ORIGIN_DIR"/*.sh "$STABLE_SCRIPTS_DIR/"
|
||||
chmod +x "$STABLE_SCRIPTS_DIR"/*.sh
|
||||
SCRIPTS_DIR="$STABLE_SCRIPTS_DIR"
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
# Adaptive polling: starts at base interval, backs off up to POLL_IDLE_MAX when
|
||||
# no agents need attention, resets on any activity or waiting_approval state.
|
||||
POLL_INTERVAL="${POLL_INTERVAL:-30}"
|
||||
POLL_IDLE_MAX=${POLL_IDLE_MAX:-300}
|
||||
POLL_CURRENT=$POLL_INTERVAL
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_state WINDOW FIELD VALUE
|
||||
# ---------------------------------------------------------------------------
|
||||
update_state() {
|
||||
local window="$1" field="$2" value="$3"
|
||||
jq --arg w "$window" --arg f "$field" --arg v "$value" \
|
||||
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
}
|
||||
|
||||
update_state_int() {
|
||||
local window="$1" field="$2" value="$3"
|
||||
jq --arg w "$window" --arg f "$field" --argjson v "$value" \
|
||||
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
}
|
||||
|
||||
agent_field() {
|
||||
jq -r --arg w "$1" --arg f "$2" \
|
||||
'.agents[] | select(.window == $w) | .[$f] // ""' \
|
||||
"$STATE_FILE" 2>/dev/null
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wait_for_prompt WINDOW — wait up to 60s for Claude's ❯ prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
wait_for_prompt() {
|
||||
local window="$1"
|
||||
for i in $(seq 1 60); do
|
||||
local cmd pane
|
||||
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
|
||||
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
|
||||
if echo "$pane" | grep -q "Enter to confirm"; then
|
||||
tmux send-keys -t "$window" Down Enter; sleep 2; continue
|
||||
fi
|
||||
[[ "$cmd" == "node" ]] && echo "$pane" | grep -q "❯" && return 0
|
||||
sleep 1
|
||||
done
|
||||
return 1 # timed out
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wait_for_claude_idle WINDOW — wait up to 30s for Claude to reach idle ❯ prompt
|
||||
# (no spinner or busy indicator visible in the last 3 lines of pane output)
|
||||
# Returns 0 when idle, 1 on timeout.
|
||||
# ---------------------------------------------------------------------------
|
||||
wait_for_claude_idle() {
|
||||
local window="$1"
|
||||
local timeout="${2:-30}"
|
||||
local elapsed=0
|
||||
while (( elapsed < timeout )); do
|
||||
local cmd pane pane_tail
|
||||
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
|
||||
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
|
||||
pane_tail=$(echo "$pane" | tail -3)
|
||||
# Check full pane (not just tail) — 'Enter to confirm' dialog can scroll above last 3 lines.
|
||||
# Do NOT reset elapsed — resetting allows an infinite loop if the dialog never clears.
|
||||
if echo "$pane" | grep -q "Enter to confirm"; then
|
||||
tmux send-keys -t "$window" Down Enter
|
||||
sleep 2; (( elapsed += 2 )); continue
|
||||
fi
|
||||
# Must be running under node (Claude is live)
|
||||
if [[ "$cmd" == "node" ]]; then
|
||||
# Idle: ❯ prompt visible AND no spinner/busy text in last 3 lines
|
||||
if echo "$pane_tail" | grep -q "❯" && \
|
||||
! echo "$pane_tail" | grep -qE '[✳✽✢✶·✻✼✿❋✤]|Running…|Compacting'; then
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
sleep 2
|
||||
(( elapsed += 2 ))
|
||||
done
|
||||
return 1 # timed out
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_kick WINDOW STATE — only for idle (crashed) agents, not stuck
|
||||
# ---------------------------------------------------------------------------
|
||||
handle_kick() {
|
||||
local window="$1" state="$2"
|
||||
[[ "$state" != "idle" ]] && return # stuck agents handled by supervisor
|
||||
|
||||
local worktree_path session_id
|
||||
worktree_path=$(agent_field "$window" "worktree_path")
|
||||
session_id=$(agent_field "$window" "session_id")
|
||||
|
||||
echo "[$(date +%H:%M:%S)] KICK restart $window — agent exited, resuming session"
|
||||
|
||||
# Wait for the shell prompt before typing — avoids sending into a still-draining pane
|
||||
wait_for_claude_idle "$window" 30 \
|
||||
|| echo "[$(date +%H:%M:%S)] KICK WARNING $window — pane still busy before resume, sending anyway"
|
||||
|
||||
# Resume the exact session so the agent retains full context — no need to re-send objective
|
||||
if [ -n "$session_id" ]; then
|
||||
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --resume '${session_id}' --permission-mode bypassPermissions" Enter
|
||||
else
|
||||
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --continue --permission-mode bypassPermissions" Enter
|
||||
fi
|
||||
|
||||
wait_for_prompt "$window" || echo "[$(date +%H:%M:%S)] KICK WARNING $window — timed out waiting for ❯"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_approve WINDOW — auto-approve dialogs that need no judgment
|
||||
# ---------------------------------------------------------------------------
|
||||
handle_approve() {
|
||||
local window="$1"
|
||||
local pane_tail
|
||||
pane_tail=$(tmux capture-pane -t "$window" -p 2>/dev/null | tail -3 || echo "")
|
||||
|
||||
# Settings error dialog at startup
|
||||
if echo "$pane_tail" | grep -q "Enter to confirm"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE dialog $window — settings error"
|
||||
tmux send-keys -t "$window" Down Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# Numbered-option dialog (e.g. "Do you want to make this edit?")
|
||||
# ❯ is already on option 1 (Yes) — Enter confirms it
|
||||
if echo "$pane_tail" | grep -qE "❯\s*1\." || echo "$pane_tail" | grep -q "Esc to cancel"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE edit $window"
|
||||
tmux send-keys -t "$window" "" Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# y/n prompt for safe operations
|
||||
if echo "$pane_tail" | grep -qiE "(^git |^npm |^pnpm |^poetry |^pytest|^docker |^make |^cargo |^pip |^yarn |curl .*(localhost|127\.0\.0\.1))"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE safe $window"
|
||||
tmux send-keys -t "$window" "y" Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# Anything else — supervisor handles it, just log
|
||||
echo "[$(date +%H:%M:%S)] APPROVE skip $window — unknown dialog, supervisor will handle"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main loop
|
||||
# ---------------------------------------------------------------------------
|
||||
echo "[$(date +%H:%M:%S)] run-loop started (mechanical only, poll ${POLL_INTERVAL}s→${POLL_IDLE_MAX}s adaptive)"
|
||||
echo "[$(date +%H:%M:%S)] Supervisor: orchestrating Claude session (not a separate window)"
|
||||
echo "---"
|
||||
|
||||
while true; do
|
||||
if ! jq -e '.active == true' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
echo "[$(date +%H:%M:%S)] active=false — exiting."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
ACTIONS=$("$SCRIPTS_DIR/poll-cycle.sh" 2>/dev/null || echo "[]")
|
||||
KICKED=0; DONE=0
|
||||
|
||||
while IFS= read -r action; do
|
||||
[ -z "$action" ] && continue
|
||||
WINDOW=$(echo "$action" | jq -r '.window // ""')
|
||||
ACTION=$(echo "$action" | jq -r '.action // ""')
|
||||
STATE=$(echo "$action" | jq -r '.state // ""')
|
||||
|
||||
case "$ACTION" in
|
||||
kick) handle_kick "$WINDOW" "$STATE" || true; KICKED=$(( KICKED + 1 )) ;;
|
||||
approve) handle_approve "$WINDOW" || true ;;
|
||||
complete) DONE=$(( DONE + 1 )) ;; # poll-cycle already set state=pending_evaluation; orchestrator handles
|
||||
esac
|
||||
done < <(echo "$ACTIONS" | jq -c '.[]' 2>/dev/null || true)
|
||||
|
||||
RUNNING=$(jq '[.agents[] | select(.state | test("running|stuck|waiting_approval|idle"))] | length' \
|
||||
"$STATE_FILE" 2>/dev/null || echo 0)
|
||||
|
||||
# Adaptive backoff: reset to base on activity or waiting_approval agents; back off when truly idle
|
||||
WAITING=$(jq '[.agents[] | select(.state == "waiting_approval")] | length' "$STATE_FILE" 2>/dev/null || echo 0)
|
||||
if (( KICKED > 0 || DONE > 0 || WAITING > 0 )); then
|
||||
POLL_CURRENT=$POLL_INTERVAL
|
||||
else
|
||||
POLL_CURRENT=$(( POLL_CURRENT + POLL_CURRENT / 2 + 1 ))
|
||||
(( POLL_CURRENT > POLL_IDLE_MAX )) && POLL_CURRENT=$POLL_IDLE_MAX
|
||||
fi
|
||||
|
||||
echo "[$(date +%H:%M:%S)] Poll — ${RUNNING} running ${KICKED} kicked ${DONE} recycled (next in ${POLL_CURRENT}s)"
|
||||
sleep "$POLL_CURRENT"
|
||||
done
|
||||
@@ -1,129 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# spawn-agent.sh — create tmux window, checkout branch, launch claude, send task
|
||||
#
|
||||
# Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]
|
||||
# SESSION — tmux session name, e.g. autogpt1
|
||||
# WORKTREE_PATH — absolute path to the git worktree
|
||||
# SPARE_BRANCH — spare branch being replaced, e.g. spare/6 (saved for recycle)
|
||||
# NEW_BRANCH — task branch to create, e.g. feat/my-feature
|
||||
# OBJECTIVE — task description sent to the agent
|
||||
# PR_NUMBER — (optional) GitHub PR number for completion verification
|
||||
# STEPS... — (optional) required checkpoint names, e.g. pr-address pr-test
|
||||
#
|
||||
# Stdout: SESSION:WINDOW_INDEX (nothing else — callers rely on this)
|
||||
# Exit non-zero on failure.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [ $# -lt 5 ]; then
|
||||
echo "Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
SESSION="$1"
|
||||
WORKTREE_PATH="$2"
|
||||
SPARE_BRANCH="$3"
|
||||
NEW_BRANCH="$4"
|
||||
OBJECTIVE="$5"
|
||||
PR_NUMBER="${6:-}"
|
||||
STEPS=("${@:7}")
|
||||
WORKTREE_NAME=$(basename "$WORKTREE_PATH")
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
# Generate a stable session ID so this agent's Claude session can always be resumed:
|
||||
# claude --resume $SESSION_ID --permission-mode bypassPermissions
|
||||
SESSION_ID=$(uuidgen 2>/dev/null || python3 -c "import uuid; print(uuid.uuid4())")
|
||||
|
||||
# Create (or switch to) the task branch
|
||||
git -C "$WORKTREE_PATH" checkout -b "$NEW_BRANCH" 2>/dev/null \
|
||||
|| git -C "$WORKTREE_PATH" checkout "$NEW_BRANCH"
|
||||
|
||||
# Open a new named tmux window; capture its numeric index
|
||||
WIN_IDX=$(tmux new-window -t "$SESSION" -n "$WORKTREE_NAME" -P -F '#{window_index}')
|
||||
WINDOW="${SESSION}:${WIN_IDX}"
|
||||
|
||||
# Append the initial agent record to the state file so subsequent jq updates find it.
|
||||
# This must happen before the pr_number/steps update below.
|
||||
if [ -f "$STATE_FILE" ]; then
|
||||
NOW=$(date +%s)
|
||||
jq --arg window "$WINDOW" \
|
||||
--arg worktree "$WORKTREE_NAME" \
|
||||
--arg worktree_path "$WORKTREE_PATH" \
|
||||
--arg spare_branch "$SPARE_BRANCH" \
|
||||
--arg branch "$NEW_BRANCH" \
|
||||
--arg objective "$OBJECTIVE" \
|
||||
--arg session_id "$SESSION_ID" \
|
||||
--argjson now "$NOW" \
|
||||
'.agents += [{
|
||||
"window": $window,
|
||||
"worktree": $worktree,
|
||||
"worktree_path": $worktree_path,
|
||||
"spare_branch": $spare_branch,
|
||||
"branch": $branch,
|
||||
"objective": $objective,
|
||||
"session_id": $session_id,
|
||||
"state": "running",
|
||||
"checkpoints": [],
|
||||
"last_output_hash": "",
|
||||
"last_seen_at": $now,
|
||||
"spawned_at": $now,
|
||||
"idle_since": 0,
|
||||
"revision_count": 0,
|
||||
"last_rebriefed_at": 0
|
||||
}]' "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
fi
|
||||
|
||||
# Store pr_number + steps in state file if provided (enables verify-complete.sh).
|
||||
# The agent record was appended above so the jq select now finds it.
|
||||
if [ -n "$PR_NUMBER" ] && [ -f "$STATE_FILE" ]; then
|
||||
if [ "${#STEPS[@]}" -gt 0 ]; then
|
||||
STEPS_JSON=$(printf '%s\n' "${STEPS[@]}" | jq -R . | jq -s .)
|
||||
else
|
||||
STEPS_JSON='[]'
|
||||
fi
|
||||
jq --arg w "$WINDOW" --arg pr "$PR_NUMBER" --argjson steps "$STEPS_JSON" \
|
||||
'.agents |= map(if .window == $w then . + {pr_number: $pr, steps: $steps, checkpoints: []} else . end)' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
fi
|
||||
|
||||
# Launch claude with a stable session ID so it can always be resumed after a crash:
|
||||
# claude --resume SESSION_ID --permission-mode bypassPermissions
|
||||
tmux send-keys -t "$WINDOW" "cd '${WORKTREE_PATH}' && claude --permission-mode bypassPermissions --session-id '${SESSION_ID}'" Enter
|
||||
|
||||
# wait_for_claude_idle — poll until the pane shows idle ❯ with no spinner in the last 3 lines.
|
||||
# Returns 0 when idle, 1 on timeout.
|
||||
_wait_idle() {
|
||||
local window="$1" timeout="${2:-60}" elapsed=0
|
||||
while (( elapsed < timeout )); do
|
||||
local cmd pane_tail
|
||||
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
|
||||
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
|
||||
pane_tail=$(echo "$pane" | tail -3)
|
||||
# Check full pane (not just tail) — 'Enter to confirm' dialog can appear above the last 3 lines
|
||||
if echo "$pane" | grep -q "Enter to confirm"; then
|
||||
tmux send-keys -t "$window" Down Enter
|
||||
sleep 2; (( elapsed += 2 )); continue
|
||||
fi
|
||||
if [[ "$cmd" == "node" ]] && \
|
||||
echo "$pane_tail" | grep -q "❯" && \
|
||||
! echo "$pane_tail" | grep -qE '[✳✽✢✶·✻✼✿❋✤]|Running…|Compacting'; then
|
||||
return 0
|
||||
fi
|
||||
sleep 2; (( elapsed += 2 ))
|
||||
done
|
||||
return 1
|
||||
}
|
||||
|
||||
# Wait up to 60s for claude to be fully interactive and idle (❯ visible, no spinner).
|
||||
if ! _wait_idle "$WINDOW" 60; then
|
||||
echo "[spawn-agent] WARNING: timed out waiting for idle ❯ prompt on $WINDOW — sending objective anyway" >&2
|
||||
fi
|
||||
|
||||
# Send the task. Split text and Enter — if combined, Enter can fire before the string
|
||||
# is fully buffered, leaving the message stuck as "[Pasted text +N lines]" unsent.
|
||||
tmux send-keys -t "$WINDOW" "${OBJECTIVE} Output each completed step as CHECKPOINT:<step-name>. When ALL steps are done, output ORCHESTRATOR:DONE on its own line."
|
||||
sleep 0.3
|
||||
tmux send-keys -t "$WINDOW" Enter
|
||||
|
||||
# Only output the window address — nothing else (callers parse this)
|
||||
echo "$WINDOW"
|
||||
@@ -1,43 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# status.sh — print orchestrator status: state file summary + live tmux pane commands
|
||||
#
|
||||
# Usage: status.sh
|
||||
# Reads: ~/.claude/orchestrator-state.json
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
if [ ! -f "$STATE_FILE" ] || ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
echo "No orchestrator state found at $STATE_FILE"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Header: active status, session, thresholds, last poll
|
||||
jq -r '
|
||||
"=== Orchestrator [\(if .active then "RUNNING" else "STOPPED" end)] ===",
|
||||
"Session: \(.tmux_session // "unknown") | Idle threshold: \(.idle_threshold_seconds // 300)s",
|
||||
"Last poll: \(if (.last_poll_at // 0) == 0 then "never" else (.last_poll_at | strftime("%H:%M:%S")) end)",
|
||||
""
|
||||
' "$STATE_FILE"
|
||||
|
||||
# Each agent: state, window, worktree/branch, truncated objective
|
||||
AGENT_COUNT=$(jq '.agents | length' "$STATE_FILE")
|
||||
if [ "$AGENT_COUNT" -eq 0 ]; then
|
||||
echo " (no agents registered)"
|
||||
else
|
||||
jq -r '
|
||||
.agents[] |
|
||||
" [\(.state | ascii_upcase)] \(.window) \(.worktree)/\(.branch)",
|
||||
" \(.objective // "" | .[0:70])"
|
||||
' "$STATE_FILE"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
|
||||
# Live pane_current_command for non-done agents
|
||||
while IFS= read -r WINDOW; do
|
||||
[ -z "$WINDOW" ] && continue
|
||||
CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "unreachable")
|
||||
echo " $WINDOW live: $CMD"
|
||||
done < <(jq -r '.agents[] | select(.state != "done") | .window' "$STATE_FILE" 2>/dev/null || true)
|
||||
@@ -1,180 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# verify-complete.sh — verify a PR task is truly done before marking the agent done
|
||||
#
|
||||
# Check order matters:
|
||||
# 1. Checkpoints — did the agent do all required steps?
|
||||
# 2. CI complete — no pending (bots post comments AFTER their check runs, must wait)
|
||||
# 3. CI passing — no failures (agent must fix before done)
|
||||
# 4. spawned_at — a new CI run was triggered after agent spawned (proves real work)
|
||||
# 5. Unresolved threads — checked AFTER CI so bot-posted comments are included
|
||||
# 6. CHANGES_REQUESTED — checked AFTER CI so bot reviews are included
|
||||
#
|
||||
# Usage: verify-complete.sh WINDOW
|
||||
# Exit 0 = verified complete; exit 1 = not complete (stderr has reason)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
WINDOW="$1"
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
PR_NUMBER=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .pr_number // ""' "$STATE_FILE" 2>/dev/null)
|
||||
STEPS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .steps // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
|
||||
CHECKPOINTS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .checkpoints // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
|
||||
WORKTREE_PATH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .worktree_path // ""' "$STATE_FILE" 2>/dev/null)
|
||||
BRANCH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .branch // ""' "$STATE_FILE" 2>/dev/null)
|
||||
SPAWNED_AT=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .spawned_at // "0"' "$STATE_FILE" 2>/dev/null || echo "0")
|
||||
|
||||
# No PR number = cannot verify
|
||||
if [ -z "$PR_NUMBER" ]; then
|
||||
echo "NOT COMPLETE: no pr_number in state — set pr_number or mark done manually" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 1: all required steps are checkpointed ---
|
||||
MISSING=""
|
||||
while IFS= read -r step; do
|
||||
[ -z "$step" ] && continue
|
||||
if ! echo "$CHECKPOINTS" | grep -qFx "$step"; then
|
||||
MISSING="$MISSING $step"
|
||||
fi
|
||||
done <<< "$STEPS"
|
||||
|
||||
if [ -n "$MISSING" ]; then
|
||||
echo "NOT COMPLETE: missing checkpoints:$MISSING on PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Resolve repo for all GitHub checks below
|
||||
REPO=$(jq -r '.repo // ""' "$STATE_FILE" 2>/dev/null || echo "")
|
||||
if [ -z "$REPO" ] && [ -n "$WORKTREE_PATH" ] && [ -d "$WORKTREE_PATH" ]; then
|
||||
REPO=$(git -C "$WORKTREE_PATH" remote get-url origin 2>/dev/null \
|
||||
| sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
|
||||
fi
|
||||
|
||||
if [ -z "$REPO" ]; then
|
||||
echo "Warning: cannot resolve repo — skipping CI/thread checks" >&2
|
||||
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓ (CI/thread checks skipped — no repo)"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
CI_BUCKETS=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket 2>/dev/null || echo "[]")
|
||||
|
||||
# --- Check 2: CI fully complete — no pending checks ---
|
||||
# Pending checks MUST finish before we check threads/reviews:
|
||||
# bots (Seer, Check PR Status, etc.) post comments and CHANGES_REQUESTED AFTER their CI check runs.
|
||||
PENDING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "pending")] | length' 2>/dev/null || echo "0")
|
||||
if [ "$PENDING" -gt 0 ]; then
|
||||
PENDING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
|
||||
| jq -r '[.[] | select(.bucket == "pending") | .name] | join(", ")' 2>/dev/null || echo "unknown")
|
||||
echo "NOT COMPLETE: $PENDING CI checks still pending on PR #$PR_NUMBER ($PENDING_NAMES)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 3: CI passing — no failures ---
|
||||
FAILING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "fail")] | length' 2>/dev/null || echo "0")
|
||||
if [ "$FAILING" -gt 0 ]; then
|
||||
FAILING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
|
||||
| jq -r '[.[] | select(.bucket == "fail") | .name] | join(", ")' 2>/dev/null || echo "unknown")
|
||||
echo "NOT COMPLETE: $FAILING failing CI checks on PR #$PR_NUMBER ($FAILING_NAMES)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 4: a new CI run was triggered AFTER the agent spawned ---
|
||||
if [ -n "$BRANCH" ] && [ "${SPAWNED_AT:-0}" -gt 0 ]; then
|
||||
LATEST_RUN_AT=$(gh run list --repo "$REPO" --branch "$BRANCH" \
|
||||
--json createdAt --limit 1 2>/dev/null | jq -r '.[0].createdAt // ""')
|
||||
if [ -n "$LATEST_RUN_AT" ]; then
|
||||
if date --version >/dev/null 2>&1; then
|
||||
LATEST_RUN_EPOCH=$(date -d "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
|
||||
else
|
||||
LATEST_RUN_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
|
||||
fi
|
||||
if [ "$LATEST_RUN_EPOCH" -le "$SPAWNED_AT" ]; then
|
||||
echo "NOT COMPLETE: latest CI run on $BRANCH predates agent spawn — agent may not have pushed yet" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
OWNER=$(echo "$REPO" | cut -d/ -f1)
|
||||
REPONAME=$(echo "$REPO" | cut -d/ -f2)
|
||||
|
||||
# --- Check 5: no unresolved review threads (checked AFTER CI — bots post after their check) ---
|
||||
UNRESOLVED=$(gh api graphql -f query="
|
||||
{ repository(owner: \"${OWNER}\", name: \"${REPONAME}\") {
|
||||
pullRequest(number: ${PR_NUMBER}) {
|
||||
reviewThreads(first: 50) { nodes { isResolved } }
|
||||
}
|
||||
}
|
||||
}
|
||||
" --jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)] | length' 2>/dev/null || echo "0")
|
||||
|
||||
if [ "$UNRESOLVED" -gt 0 ]; then
|
||||
echo "NOT COMPLETE: $UNRESOLVED unresolved review threads on PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 6: no CHANGES_REQUESTED (checked AFTER CI — bots post reviews after their check) ---
|
||||
# A CHANGES_REQUESTED review is stale if the latest commit was pushed AFTER the review was submitted.
|
||||
# Stale reviews (pre-dating the fixing commits) should not block verification.
|
||||
#
|
||||
# Fetch commits and latestReviews in a single call and fail closed — if gh fails,
|
||||
# treat that as NOT COMPLETE rather than silently passing.
|
||||
# Use latestReviews (not reviews) so each reviewer's latest state is used — superseded
|
||||
# CHANGES_REQUESTED entries are automatically excluded when the reviewer later approved.
|
||||
# Note: we intentionally use committedDate (not PR updatedAt) because updatedAt changes on any
|
||||
# PR activity (bot comments, label changes) which would create false negatives.
|
||||
PR_REVIEW_METADATA=$(gh pr view "$PR_NUMBER" --repo "$REPO" \
|
||||
--json commits,latestReviews 2>/dev/null) || {
|
||||
echo "NOT COMPLETE: unable to fetch PR review metadata for PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
LATEST_COMMIT_DATE=$(jq -r '.commits[-1].committedDate // ""' <<< "$PR_REVIEW_METADATA")
|
||||
CHANGES_REQUESTED_REVIEWS=$(jq '[.latestReviews[]? | select(.state == "CHANGES_REQUESTED")]' <<< "$PR_REVIEW_METADATA")
|
||||
|
||||
BLOCKING_CHANGES_REQUESTED=0
|
||||
BLOCKING_REQUESTERS=""
|
||||
|
||||
if [ -n "$LATEST_COMMIT_DATE" ] && [ "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length)" -gt 0 ]; then
|
||||
if date --version >/dev/null 2>&1; then
|
||||
LATEST_COMMIT_EPOCH=$(date -d "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
else
|
||||
LATEST_COMMIT_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
fi
|
||||
|
||||
while IFS= read -r review; do
|
||||
[ -z "$review" ] && continue
|
||||
REVIEW_DATE=$(echo "$review" | jq -r '.submittedAt // ""')
|
||||
REVIEWER=$(echo "$review" | jq -r '.author.login // "unknown"')
|
||||
if [ -z "$REVIEW_DATE" ]; then
|
||||
# No submission date — treat as fresh (conservative: blocks verification)
|
||||
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
|
||||
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
|
||||
else
|
||||
if date --version >/dev/null 2>&1; then
|
||||
REVIEW_EPOCH=$(date -d "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
else
|
||||
REVIEW_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
fi
|
||||
if [ "$REVIEW_EPOCH" -gt "$LATEST_COMMIT_EPOCH" ]; then
|
||||
# Review was submitted AFTER latest commit — still fresh, blocks verification
|
||||
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
|
||||
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
|
||||
fi
|
||||
# Review submitted BEFORE latest commit — stale, skip
|
||||
fi
|
||||
done <<< "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -c '.[]')"
|
||||
else
|
||||
# No commit date or no changes_requested — check raw count as fallback
|
||||
BLOCKING_CHANGES_REQUESTED=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length 2>/dev/null || echo "0")
|
||||
BLOCKING_REQUESTERS=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -r '[.[].author.login] | join(", ")' 2>/dev/null || echo "unknown")
|
||||
fi
|
||||
|
||||
if [ "$BLOCKING_CHANGES_REQUESTED" -gt 0 ]; then
|
||||
echo "NOT COMPLETE: CHANGES_REQUESTED (after latest commit) from ${BLOCKING_REQUESTERS} on PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓, CI complete + green, 0 unresolved threads, no CHANGES_REQUESTED"
|
||||
exit 0
|
||||
@@ -17,110 +17,43 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Read the PR description
|
||||
|
||||
Understand the **Why / What / How** before addressing comments — you need context to make good fixes:
|
||||
|
||||
```bash
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
> If GraphQL is rate-limited, `gh pr view` fails. See [GitHub rate limits](#github-rate-limits) for REST fallbacks.
|
||||
|
||||
## Fetch comments (all sources)
|
||||
|
||||
### 1. Inline review threads — GraphQL (primary source of actionable items)
|
||||
|
||||
> ⚠️ **WARNING — PAGINATE ALL PAGES BEFORE ADDRESSING ANYTHING**
|
||||
>
|
||||
> `reviewThreads(first: 100)` returns at most 100 threads per page AND returns threads **oldest-first**. On a PR with many review cycles (e.g. 373 threads), the oldest 100–200 threads are from past cycles and are **all already resolved**. Filtering client-side with `select(.isResolved == false)` on page 1 therefore yields **0 results** — even though pages 2–4 contain many unresolved threads from recent review cycles.
|
||||
>
|
||||
> **This is the most common failure mode:** agent fetches page 1, sees 0 unresolved after filtering, stops pagination, reports "done" — while hundreds of unresolved threads sit on later pages.
|
||||
>
|
||||
> One observed PR had 142 total threads: page 1 returned 0 unresolved (all old/resolved), while pages 2–3 had 111 unresolved. Another with 373 threads across 4 pages also had page 1 entirely resolved.
|
||||
>
|
||||
> **The rule: ALWAYS paginate to `hasNextPage == false` regardless of the per-page unresolved count. Never stop early because a page returns 0 unresolved.**
|
||||
|
||||
**Step 1 — Fetch total count and sanity-check the newest threads:**
|
||||
Use GraphQL to fetch inline threads. It natively exposes `isResolved`, returns threads already grouped with all replies, and paginates via cursor — no manual thread reconstruction needed.
|
||||
|
||||
```bash
|
||||
# Get total count and the newest 100 threads (last: 100 returns newest-first)
|
||||
gh api graphql -f query='
|
||||
{
|
||||
repository(owner: "Significant-Gravitas", name: "AutoGPT") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads { totalCount }
|
||||
newest: reviewThreads(last: 100) {
|
||||
nodes { isResolved }
|
||||
}
|
||||
}
|
||||
}
|
||||
}' | jq '{ total: .data.repository.pullRequest.reviewThreads.totalCount, newest_unresolved: [.data.repository.pullRequest.newest.nodes[] | select(.isResolved == false)] | length }'
|
||||
```
|
||||
|
||||
If `total > 100`, you have multiple pages — you **must** paginate all of them regardless of what `newest_unresolved` shows. The `last: 100` check is a sanity signal only; the full loop below is mandatory.
|
||||
|
||||
**Step 2 — Collect all unresolved thread IDs across all pages:**
|
||||
|
||||
```bash
|
||||
# Accumulate all unresolved threads — loop until hasNextPage == false
|
||||
CURSOR=""
|
||||
ALL_THREADS="[]"
|
||||
while true; do
|
||||
AFTER=${CURSOR:+", after: \"$CURSOR\""}
|
||||
PAGE=$(gh api graphql -f query="
|
||||
{
|
||||
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads(first: 100${AFTER}) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes {
|
||||
id
|
||||
isResolved
|
||||
path
|
||||
line
|
||||
comments(last: 1) {
|
||||
nodes { databaseId body author { login } }
|
||||
}
|
||||
reviewThreads(first: 100) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes {
|
||||
id
|
||||
isResolved
|
||||
path
|
||||
comments(last: 1) {
|
||||
nodes { databaseId body author { login } createdAt }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}")
|
||||
# Append unresolved nodes from this page
|
||||
PAGE_THREADS=$(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)]')
|
||||
ALL_THREADS=$(echo "$ALL_THREADS $PAGE_THREADS" | jq -s 'add')
|
||||
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
|
||||
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
|
||||
[ "$HAS_NEXT" = "false" ] && break
|
||||
done
|
||||
|
||||
# Reverse so newest threads (last pages) are addressed first — GitHub returns oldest-first
|
||||
# and the most recent review cycle's comments are the ones blocking approval.
|
||||
ALL_THREADS=$(echo "$ALL_THREADS" | jq 'reverse')
|
||||
|
||||
echo "Total unresolved threads: $(echo "$ALL_THREADS" | jq 'length')"
|
||||
echo "$ALL_THREADS" | jq '[.[] | {id, path, line, body: .comments.nodes[0].body[:200]}]'
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
**Step 3 — Address every thread in `ALL_THREADS`, then resolve.**
|
||||
|
||||
Only after this loop completes (all pages fetched, count confirmed) should you begin making fixes.
|
||||
|
||||
> **Why reverse?** GraphQL returns threads oldest-first and exposes no `orderBy` option. A PR with 373 threads has ~4 pages; threads from the latest review cycle land on the last pages. Processing in reverse ensures the newest, most blocking comments are addressed first — the earlier pages mostly contain outdated threads from prior cycles.
|
||||
If `pageInfo.hasNextPage` is true, fetch subsequent pages by adding `after: "<endCursor>"` to `reviewThreads(first: 100, after: "...")` and repeat until `hasNextPage` is false.
|
||||
|
||||
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
|
||||
|
||||
> If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the REST fallback (flat comment list — no thread grouping or `isResolved`).
|
||||
|
||||
### 2. Top-level reviews — REST (MUST paginate)
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
|
||||
```
|
||||
|
||||
> **Already REST — unaffected by GraphQL rate limits or outages. Continue polling reviews normally even when GraphQL is exhausted.**
|
||||
|
||||
**CRITICAL — always `--paginate`.** Reviews default to 30 per page. PRs can have 80–170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30 — including `autogpt-reviewer`'s structured review which is typically posted after several CI runs and sits well beyond the first page.
|
||||
|
||||
Two things to extract:
|
||||
@@ -139,71 +72,20 @@ Two things to extract:
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
|
||||
```
|
||||
|
||||
> **Already REST — unaffected by GraphQL rate limits.**
|
||||
|
||||
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
|
||||
|
||||
## For each unaddressed comment
|
||||
|
||||
**CRITICAL: The only valid sequence is fix → commit → push → reply → resolve. Never resolve a thread without a real code commit.**
|
||||
|
||||
Resolving a thread via `resolveReviewThread` without an actual fix is the most common failure mode — it makes unresolved counts drop without any real change, producing a false "done" signal. If the issue was genuinely a false positive (no code change needed), reply explaining why and then resolve. Otherwise:
|
||||
|
||||
Address comments **one at a time**: fix → commit → push → inline reply → resolve.
|
||||
Address comments **one at a time**: fix → commit → push → inline reply → next.
|
||||
|
||||
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
|
||||
2. Commit and push the fix
|
||||
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
|
||||
|
||||
Use a **markdown commit link** so GitHub renders it as a clickable reference. Always get the full SHA with `git rev-parse HEAD` **after** committing — never copy a SHA from a previous commit or hardcode one:
|
||||
|
||||
```bash
|
||||
FULL_SHA=$(git rev-parse HEAD)
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies \
|
||||
-f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): <description>"
|
||||
```
|
||||
|
||||
| Comment type | How to reply |
|
||||
|---|---|
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
|
||||
|
||||
### What counts as a valid resolution
|
||||
|
||||
Only two situations justify calling `resolveReviewThread`:
|
||||
|
||||
1. **Real code fix**: you changed the code, committed + pushed, and replied with the SHA. The commit diff must actually address the concern — not just touch the same file.
|
||||
2. **Genuine false positive**: the reviewer's concern does not apply to this code, and you can give a specific technical reason (e.g. "Not applicable — `sdk_cwd` is pre-validated by `_make_sdk_cwd()` which applies normpath + prefix assertion before reaching this point").
|
||||
|
||||
**Anti-patterns that look resolved but aren't — never do these:**
|
||||
- `"Accepted, tracked as follow-up"` — a deferral, not a fix. The concern is still open. Do not resolve.
|
||||
- `"Acknowledged"` or `"Same as above"` — these are acknowledgements, not fixes. Do not resolve.
|
||||
- `"Fixed in abc1234"` where `abc1234` is a commit that doesn't actually change the flagged line/logic — dishonest. Verify `git show abc1234 -- path/to/file` changes the right thing before posting.
|
||||
- Resolving without replying — the reviewer never sees what happened.
|
||||
|
||||
When in doubt: if a code change is needed, make it. A deferred issue means the thread stays open until the follow-up PR is merged.
|
||||
|
||||
## Codecov coverage
|
||||
|
||||
Codecov patch target is **80%** on changed lines. Checks are **informational** (not blocking) but should be green.
|
||||
|
||||
### Running coverage locally
|
||||
|
||||
**Backend** (from `autogpt_platform/backend/`):
|
||||
```bash
|
||||
poetry run pytest -s -vv --cov=backend --cov-branch --cov-report term-missing
|
||||
```
|
||||
|
||||
**Frontend** (from `autogpt_platform/frontend/`):
|
||||
```bash
|
||||
pnpm vitest run --coverage
|
||||
```
|
||||
|
||||
### When codecov/patch fails
|
||||
|
||||
1. Find uncovered files: `git diff --name-only $(gh pr view --json baseRefName --jq '.baseRefName')...HEAD`
|
||||
2. For each uncovered file — extract inline logic to `helpers.ts`/`helpers.py` and test those (highest ROI). Colocate tests as `*_test.py` (backend) or `__tests__/*.test.ts` (frontend).
|
||||
3. Run coverage locally to verify, commit, push.
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
|
||||
## Format and commit
|
||||
|
||||
@@ -223,28 +105,10 @@ kill $REST_PID 2>/dev/null; trap - EXIT
|
||||
```
|
||||
Never manually edit files in `src/app/api/__generated__/`.
|
||||
|
||||
Then commit and **push immediately** — never batch commits without pushing. Each fix should be visible on GitHub right away so CI can start and reviewers can see progress.
|
||||
|
||||
**Never push empty commits** (`git commit --allow-empty`) to re-trigger CI or bot checks. When a check fails, investigate the root cause (unchecked PR checklist, unaddressed review comments, code issues) and fix those directly. Empty commits add noise to git history.
|
||||
Then commit and **push immediately** — never batch commits without pushing.
|
||||
|
||||
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
## Coverage
|
||||
|
||||
Codecov enforces patch coverage on new/changed lines — new code you write must be tested. Before pushing, verify you haven't left new lines uncovered:
|
||||
|
||||
```bash
|
||||
cd autogpt_platform/backend
|
||||
poetry run pytest --cov=. --cov-report=term-missing {path/to/changed/module}
|
||||
```
|
||||
|
||||
Look for lines marked `miss` — those are uncovered. Add tests for any new code you wrote as part of addressing comments.
|
||||
|
||||
**Rules:**
|
||||
- New code you add should have tests
|
||||
- Don't remove existing tests when fixing comments
|
||||
- If a reviewer asks you to delete code, also delete its tests, but verify coverage hasn't dropped on remaining lines
|
||||
|
||||
## The loop
|
||||
|
||||
```text
|
||||
@@ -334,162 +198,3 @@ git push
|
||||
```
|
||||
|
||||
5. Restart the polling loop from the top — new commits reset CI status.
|
||||
|
||||
## GitHub rate limits
|
||||
|
||||
Three distinct rate limits exist — they have different causes, error shapes, and recovery times:
|
||||
|
||||
| Error | HTTP code | Cause | Recovery |
|
||||
|---|---|---|---|
|
||||
| `{"code":"abuse"}` | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **2–3 minutes**. 60s is often not enough. |
|
||||
| `{"message":"API rate limit exceeded"}` | 429 | Primary REST rate limit — 5000 calls/hr per user | Wait until `X-RateLimit-Reset` header timestamp |
|
||||
| `GraphQL: API rate limit already exceeded for user ID ...` | 403 on stderr, `gh` exits 1 | **GraphQL-specific** per-user limit — distinct from REST's 5000/hr and from the abuse secondary limit. Trips faster than REST because point costs per query. | Wait until the GraphQL window resets (typically ~1 hour from the first call in the window). REST still works — use fallbacks below. |
|
||||
|
||||
**Prevention:** Add `sleep 3` between individual thread reply API calls. When posting >20 replies, increase to `sleep 5`.
|
||||
|
||||
### Detection
|
||||
|
||||
The `gh` CLI surfaces the GraphQL limit on stderr with the exact string `GraphQL: API rate limit already exceeded for user ID <id>` and exits 1 — any `gh api graphql ...` **or** `gh pr view ...` call fails. Check current quota and reset time via the REST endpoint that reports GraphQL quota (this call is REST and still works whether GraphQL is rate-limited OR fully down):
|
||||
|
||||
```bash
|
||||
gh api rate_limit --jq '.resources.graphql' # { "limit": 5000, "used": 5000, "remaining": 0, "reset": 1729...}
|
||||
# Human-readable reset:
|
||||
gh api rate_limit --jq '.resources.graphql.reset' | xargs -I{} date -r {}
|
||||
```
|
||||
|
||||
Retry when `remaining > 0`. If you need to proceed sooner, sleep 2–5 min and probe again — the limit is per user, not per machine, so other concurrent agents under the same token also consume it.
|
||||
|
||||
### What keeps working
|
||||
|
||||
When GraphQL is unavailable (rate-limited or outage):
|
||||
|
||||
- **Keeps working (REST):** top-level reviews fetch, conversation comments fetch, all inline-comment replies, CI status (`gh pr checks`), and the `gh api rate_limit` probe.
|
||||
- **Degraded:** inline thread list — fall back to flat `/pulls/{N}/comments` REST, which drops thread grouping, `isResolved`, and Relay thread IDs. You still get comment bodies and the `databaseId` as `id`, enough to read and reply.
|
||||
- **Blocked:** `gh pr view`, the `resolveReviewThread` mutation, and any new `gh api graphql` queries — wait for the quota to reset.
|
||||
|
||||
### Fall back to REST
|
||||
|
||||
**PR metadata reads** — `gh pr view` uses GraphQL under the hood; use the REST pulls endpoint instead, which returns the full PR object:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.body' # == --json body
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.base.ref' # == --json baseRefName
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.mergeable' # == --json mergeable
|
||||
```
|
||||
|
||||
Note: REST `mergeable` returns `true|false|null`; GraphQL returns `MERGEABLE|CONFLICTING|UNKNOWN`. The `null` case maps to `UNKNOWN` — treat it the same (still computing; poll again).
|
||||
|
||||
**Inline comments (flat list)** — no thread grouping or `isResolved`, but enough to read and reply:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate \
|
||||
| jq '[.[] | {id, path, line, user: .user.login, body: .body[:200], in_reply_to_id}]'
|
||||
```
|
||||
|
||||
Use this degraded mode to make progress on the fix → reply loop, then return to GraphQL for `resolveReviewThread` once the rate limit resets.
|
||||
|
||||
**Replies** — already REST-native (`/pulls/{N}/comments/{ID}/replies`); no change needed, use the same command as the main flow.
|
||||
|
||||
**`resolveReviewThread`** — **no REST equivalent**; GitHub does not expose a REST endpoint for thread resolution. Queue the thread IDs needing resolution, wait for the GraphQL limit to reset, then run the resolve mutations in a batch (with `sleep 3` between calls, per the secondary-limit guidance).
|
||||
|
||||
### Recovery from secondary rate limit (403 abuse)
|
||||
|
||||
1. Stop all API writes immediately
|
||||
2. Wait **2 minutes minimum** (not 60s — secondary limits are stricter)
|
||||
3. Resume with `sleep 3` between each call
|
||||
4. If 403 persists after 2 min, wait another 2 min before retrying
|
||||
|
||||
Never batch all replies in a tight loop — always space them out.
|
||||
|
||||
## Parallel thread resolution
|
||||
|
||||
When a PR has more than 10 unresolved threads, addressing one commit per thread is slow. Use this strategy instead:
|
||||
|
||||
### Group by file, batch per commit
|
||||
|
||||
1. Sort `ALL_THREADS` by `path` — threads in the same file can share a single commit.
|
||||
2. Fix all threads in one file → `git commit` → `git push` → reply to **all** those threads with the same SHA → resolve them all.
|
||||
3. Move to the next file group and repeat.
|
||||
|
||||
This reduces N commits to (number of files touched), which is usually 3–5 instead of 15–30.
|
||||
|
||||
### Posting replies concurrently (for large batches)
|
||||
|
||||
For truly independent thread groups (different files, no shared logic), you can post replies in parallel using background subshells — but always space out API writes:
|
||||
|
||||
```bash
|
||||
# Post replies to a batch of threads concurrently, 3s apart
|
||||
(
|
||||
sleep 3
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID1}/replies \
|
||||
-f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): ..."
|
||||
) &
|
||||
(
|
||||
sleep 6
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID2}/replies \
|
||||
-f body="🤖 Fixed in [${FULL_SHA:0:9}](https://github.com/Significant-Gravitas/AutoGPT/commit/${FULL_SHA}): ..."
|
||||
) &
|
||||
wait # wait for all background replies before resolving
|
||||
```
|
||||
|
||||
Then resolve sequentially (GraphQL mutations):
|
||||
```bash
|
||||
for THREAD_ID in "$THREAD1" "$THREAD2" "$THREAD3"; do
|
||||
gh api graphql -f query="mutation { resolveReviewThread(input: {threadId: \"${THREAD_ID}\"}) { thread { isResolved } } }"
|
||||
sleep 3
|
||||
done
|
||||
```
|
||||
|
||||
**Always sleep 3s between individual API writes** — GitHub's secondary rate limit (403) triggers on bursts of >20 writes. Increase to `sleep 5` when posting more than 20 replies in a batch.
|
||||
|
||||
## Resolving threads via GraphQL
|
||||
|
||||
Use `resolveReviewThread` **only after** the commit is pushed and the reply is posted:
|
||||
|
||||
```bash
|
||||
gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "THREAD_ID"}) { thread { isResolved } } }'
|
||||
```
|
||||
|
||||
**Never call this mutation before committing the fix.** The orchestrator will verify actual unresolved counts via GraphQL after you output `ORCHESTRATOR:DONE` — false resolutions will be caught and you will be re-briefed.
|
||||
|
||||
> `resolveReviewThread` is GraphQL-only — no REST equivalent. If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the queue-and-retry flow.
|
||||
|
||||
### Verify actual count before outputting ORCHESTRATOR:DONE
|
||||
|
||||
Before claiming "0 unresolved threads", always query GitHub directly — don't rely on your own bookkeeping. Paginate all pages — a single `first: 100` query misses threads beyond page 1:
|
||||
|
||||
```bash
|
||||
# Step 1: get total thread count
|
||||
gh api graphql -f query='
|
||||
{
|
||||
repository(owner: "Significant-Gravitas", name: "AutoGPT") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads { totalCount }
|
||||
}
|
||||
}
|
||||
}' | jq '.data.repository.pullRequest.reviewThreads.totalCount'
|
||||
|
||||
# Step 2: paginate all pages, count truly unresolved
|
||||
CURSOR=""; UNRESOLVED=0
|
||||
while true; do
|
||||
AFTER=${CURSOR:+", after: \"$CURSOR\""}
|
||||
PAGE=$(gh api graphql -f query="
|
||||
{
|
||||
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads(first: 100${AFTER}) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes { isResolved }
|
||||
}
|
||||
}
|
||||
}
|
||||
}")
|
||||
UNRESOLVED=$(( UNRESOLVED + $(echo "$PAGE" | jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved==false)] | length') ))
|
||||
HAS_NEXT=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage')
|
||||
CURSOR=$(echo "$PAGE" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.endCursor')
|
||||
[ "$HAS_NEXT" = "false" ] && break
|
||||
done
|
||||
echo "Unresolved threads: $UNRESOLVED"
|
||||
```
|
||||
|
||||
Only output `ORCHESTRATOR:DONE` after this loop reports 0.
|
||||
|
||||
@@ -1,245 +0,0 @@
|
||||
---
|
||||
name: pr-polish
|
||||
description: Alternate /pr-review and /pr-address on a PR until the PR is truly mergeable — no new review findings, zero unresolved inline threads, zero unaddressed top-level reviews or issue comments, all CI checks green, and two consecutive quiet polls after CI settles. Use when the user wants a PR polished to merge-ready without setting a fixed number of rounds.
|
||||
user-invocable: true
|
||||
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Polish
|
||||
|
||||
**Goal.** Drive a PR to merge-ready by alternating `/pr-review` and `/pr-address` until **all** of the following hold:
|
||||
|
||||
1. The most recent `/pr-review` produces **zero new findings** (no new inline comments, no new top-level reviews with a non-empty body).
|
||||
2. Every inline review thread reachable via GraphQL reports `isResolved: true`.
|
||||
3. Every non-bot, non-author top-level review has been acknowledged (replied-to) OR resolved via a thread it spawned.
|
||||
4. Every non-bot, non-author issue comment has been acknowledged (replied-to).
|
||||
5. Every CI check is `conclusion: "success"` or `"skipped"` / `"neutral"` — none `"failure"` or still pending.
|
||||
6. **Two consecutive post-CI polls** (≥60s apart) stay clean — no new threads, no new non-empty reviews, no new issue comments. Bots (coderabbitai, sentry, autogpt-reviewer) frequently post late after CI settles; a single green snapshot is not sufficient.
|
||||
|
||||
**Do not stop at a fixed number of rounds.** If round N introduces new comments, round N+1 is required. Cap at `_MAX_ROUNDS = 10` as a safety valve, but expect 2–5 in practice.
|
||||
|
||||
## TodoWrite
|
||||
|
||||
Before starting, write two todos so the user can see the loop progression:
|
||||
|
||||
- `Round {current}: /pr-review + /pr-address on PR #{N}` — current iteration.
|
||||
- `Final polish polling: 2 consecutive clean polls, CI green, 0 unresolved` — runs after the last non-empty review round.
|
||||
|
||||
Update the `current` round counter at the start of each iteration; mark `completed` only when the round's address step finishes (all new threads addressed + resolved).
|
||||
|
||||
## Find the PR
|
||||
|
||||
```bash
|
||||
ARG_PR="${ARG:-}"
|
||||
# Normalize URL → numeric ID if the skill arg is a pull-request URL.
|
||||
if [[ "$ARG_PR" =~ ^https?://github\.com/[^/]+/[^/]+/pull/([0-9]+) ]]; then
|
||||
ARG_PR="${BASH_REMATCH[1]}"
|
||||
fi
|
||||
PR="${ARG_PR:-$(gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')}"
|
||||
if [ -z "$PR" ] || [ "$PR" = "null" ]; then
|
||||
echo "No PR found for current branch. Provide a PR number or URL as the skill arg."
|
||||
exit 1
|
||||
fi
|
||||
echo "Polishing PR #$PR"
|
||||
```
|
||||
|
||||
## The outer loop
|
||||
|
||||
```text
|
||||
round = 0
|
||||
while round < _MAX_ROUNDS:
|
||||
round += 1
|
||||
baseline = snapshot_state(PR) # see "Snapshotting state" below
|
||||
invoke_skill("pr-review", PR) # posts findings as inline comments / top-level review
|
||||
findings = diff_state(PR, baseline)
|
||||
if findings.total == 0:
|
||||
break # no new findings → go to polish polling
|
||||
invoke_skill("pr-address", PR) # resolves every unresolved thread + CI failure
|
||||
# Post-loop: polish polling (see below).
|
||||
polish_polling(PR)
|
||||
```
|
||||
|
||||
### Snapshotting state
|
||||
|
||||
Before each `/pr-review`, capture a baseline so the diff after the review reflects **only** what the review just added (not pre-existing threads):
|
||||
|
||||
```bash
|
||||
# Inline threads — total count + latest databaseId per thread
|
||||
gh api graphql -f query="
|
||||
{
|
||||
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
|
||||
pullRequest(number: ${PR}) {
|
||||
reviewThreads(first: 100) {
|
||||
totalCount
|
||||
nodes {
|
||||
id
|
||||
isResolved
|
||||
comments(last: 1) { nodes { databaseId } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}" > /tmp/baseline_threads.json
|
||||
|
||||
# Top-level reviews — count + latest id per non-empty review
|
||||
gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}/reviews" --paginate \
|
||||
--jq '[.[] | select((.body // "") != "") | {id, user: .user.login, state, submitted_at}]' \
|
||||
> /tmp/baseline_reviews.json
|
||||
|
||||
# Issue comments — count + latest id per non-bot, non-author comment.
|
||||
# Bots are filtered by User.type == "Bot" (GitHub sets this for app/bot
|
||||
# accounts like coderabbitai, github-actions, sentry-io). The author is
|
||||
# filtered by comparing login to the PR author — export it so jq can see it.
|
||||
AUTHOR=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}" --jq '.user.login')
|
||||
gh api "repos/Significant-Gravitas/AutoGPT/issues/${PR}/comments" --paginate \
|
||||
--jq --arg author "$AUTHOR" \
|
||||
'[.[] | select(.user.type != "Bot" and .user.login != $author)
|
||||
| {id, user: .user.login, created_at}]' \
|
||||
> /tmp/baseline_issue_comments.json
|
||||
```
|
||||
|
||||
### Diffing after a review
|
||||
|
||||
After `/pr-review` runs, any of these counting as "new findings" means another address round is needed:
|
||||
|
||||
- New inline thread `id` not in the baseline.
|
||||
- An existing thread whose latest comment `databaseId` is higher than the baseline's (new reply on an old thread).
|
||||
- A new top-level review `id` with a non-empty body.
|
||||
- A new issue comment `id` from a non-bot, non-author user.
|
||||
|
||||
If any of the four buckets is non-empty → not done; invoke `/pr-address` and loop.
|
||||
|
||||
## Polish polling
|
||||
|
||||
Once `/pr-review` produces zero new findings, do **not** exit yet. Bots (coderabbitai, sentry, autogpt-reviewer) commonly post late reviews after CI settles — 30–90 seconds after the final push. Poll at 60-second intervals:
|
||||
|
||||
```text
|
||||
NON_SUCCESS_TERMINAL = {"failure", "cancelled", "timed_out", "action_required", "startup_failure"}
|
||||
clean_polls = 0
|
||||
required_clean = 2
|
||||
while clean_polls < required_clean:
|
||||
# 1. CI gate — any terminal non-success conclusion (not just "failure")
|
||||
# must trigger /pr-address. "success", "skipped", "neutral" are clean;
|
||||
# anything else (including cancelled, timed_out, action_required) is a
|
||||
# blocker that won't self-resolve.
|
||||
ci = fetch_check_runs(PR)
|
||||
if any ci.conclusion in NON_SUCCESS_TERMINAL:
|
||||
invoke_skill("pr-address", PR) # address failures + any new comments
|
||||
baseline = snapshot_state(PR) # reset — push during address invalidates old baseline
|
||||
clean_polls = 0
|
||||
continue
|
||||
if any ci.conclusion is None (still in_progress):
|
||||
sleep 60; continue # wait without counting this as clean
|
||||
|
||||
# 2. Comment / thread gate
|
||||
threads = fetch_unresolved_threads(PR)
|
||||
new_issue_comments = diff_against_baseline(issue_comments)
|
||||
new_reviews = diff_against_baseline(reviews)
|
||||
if threads or new_issue_comments or new_reviews:
|
||||
invoke_skill("pr-address", PR)
|
||||
baseline = snapshot_state(PR) # reset — the address loop just dealt with these,
|
||||
# otherwise they stay "new" relative to the old baseline forever
|
||||
clean_polls = 0
|
||||
continue
|
||||
|
||||
# 3. Mergeability gate
|
||||
mergeable = gh api repos/.../pulls/${PR} --jq '.mergeable'
|
||||
if mergeable == false (CONFLICTING):
|
||||
resolve_conflicts(PR) # see pr-address skill
|
||||
clean_polls = 0
|
||||
continue
|
||||
if mergeable is null (UNKNOWN):
|
||||
sleep 60; continue
|
||||
|
||||
clean_polls += 1
|
||||
sleep 60
|
||||
```
|
||||
|
||||
Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`.
|
||||
|
||||
### Why 2 clean polls, not 1
|
||||
|
||||
A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments arriving gives high confidence nothing else is incoming.
|
||||
|
||||
### Why checking every source each poll
|
||||
|
||||
`/pr-address` polling inside a single round already re-checks its own comments, but `/pr-polish` sits a level above and must also catch:
|
||||
|
||||
- New top-level reviews (autogpt-reviewer sometimes posts structured feedback only after several CI green cycles).
|
||||
- Issue comments from human reviewers (not caught by inline thread polling).
|
||||
- Sentry bug predictions that land on new line numbers post-push.
|
||||
- Merge conflicts introduced by a race between your push and a merge to `dev`.
|
||||
|
||||
## Invocation pattern
|
||||
|
||||
Delegate to existing skills with the `Skill` tool; do not re-implement the review or address logic inline. This keeps the polish loop focused on orchestration and lets the child skills evolve independently.
|
||||
|
||||
```python
|
||||
Skill(skill="pr-review", args=pr_url)
|
||||
Skill(skill="pr-address", args=pr_url)
|
||||
```
|
||||
|
||||
After each child invocation, re-query GitHub state directly — never trust a summary for the stop condition. The orchestrator's `ORCHESTRATOR:DONE` is verified against actual GraphQL / REST responses per the rules in `pr-address`'s "Verify actual count before outputting ORCHESTRATOR:DONE" section.
|
||||
|
||||
### **Auto-continue: do NOT end your response between child skills**
|
||||
|
||||
`/pr-polish` is a single orchestration task — one invocation drives the PR all the way to merge-ready. When a child `Skill()` call returns control to you:
|
||||
|
||||
- Do NOT summarize and stop.
|
||||
- Do NOT wait for user confirmation to continue.
|
||||
- Immediately, in the same response, perform the next loop step: state diff → decide next action → next `Skill()` call or polling sleep.
|
||||
|
||||
The child skill returning is a **loop iteration boundary**, not a conversation turn boundary. You are expected to keep going until one of the exit conditions in the opening section is met (2 consecutive clean polls, `_MAX_ROUNDS` hit, or an unrecoverable error).
|
||||
|
||||
If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions.
|
||||
|
||||
## GitHub rate limits
|
||||
|
||||
This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off:
|
||||
|
||||
- Fall back to REST for reads (flat `/pulls/{N}/comments`, `/pulls/{N}/reviews`, `/issues/{N}/comments`) per the `pr-address` skill's GraphQL-fallback section.
|
||||
- Queue thread resolutions (GraphQL-only) until the budget resets; keep making progress on fixes + REST replies meanwhile.
|
||||
- `sleep 5` between any batch of ≥20 writes to avoid secondary rate limits.
|
||||
|
||||
## Safety valves
|
||||
|
||||
- `_MAX_ROUNDS = 10` — if review+address rounds exceed this, stop and escalate to the user with a summary of what's still unresolved. A PR that cannot converge in 10 rounds has systemic issues that need human judgment.
|
||||
- After each commit, run `poetry run format` / `pnpm format && pnpm lint && pnpm types` per the target codebase's conventions. A failing format check is CI `failure` that will never self-resolve.
|
||||
- Every `/pr-review` round checks for **duplicate** concerns first (via `pr-review`'s own "Fetch existing review comments" step) so the loop does not re-post the same finding that a prior round already resolved.
|
||||
|
||||
## Reporting
|
||||
|
||||
When the skill finishes (either via two clean polls or hitting `_MAX_ROUNDS`), produce a compact summary:
|
||||
|
||||
```
|
||||
PR #{N} polish complete ({rounds_completed} rounds):
|
||||
- {X} inline threads opened and resolved
|
||||
- {Y} CI failures fixed
|
||||
- {Z} new commits pushed
|
||||
Final state: CI green, {total} threads all resolved, mergeable.
|
||||
```
|
||||
|
||||
If exiting via `_MAX_ROUNDS`, flag explicitly:
|
||||
|
||||
```
|
||||
PR #{N} polish stopped at {_MAX_ROUNDS} rounds — NOT merge-ready:
|
||||
- {N} threads still unresolved: {titles}
|
||||
- CI status: {summary}
|
||||
Needs human review.
|
||||
```
|
||||
|
||||
## When to use this skill
|
||||
|
||||
Use when the user says any of:
|
||||
- "polish this PR"
|
||||
- "keep reviewing and addressing until it's mergeable"
|
||||
- "loop /pr-review + /pr-address until done"
|
||||
- "make sure the PR is actually merge-ready"
|
||||
|
||||
Do **not** use when:
|
||||
- User wants just one review pass (→ `/pr-review`).
|
||||
- User wants to address already-posted comments without further self-review (→ `/pr-address`).
|
||||
- A fixed round count is explicitly requested (e.g., "do 3 rounds") — honour the count instead of converging.
|
||||
@@ -17,16 +17,6 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Read the PR description
|
||||
|
||||
Before reading code, understand the **why**, **what**, and **how** from the PR description:
|
||||
|
||||
```bash
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
Every PR should have a Why / What / How structure. If any of these are missing, note it as feedback.
|
||||
|
||||
## Read the diff
|
||||
|
||||
```bash
|
||||
@@ -44,8 +34,6 @@ gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
|
||||
|
||||
## What to check
|
||||
|
||||
**Description quality:** Does the PR description cover Why (motivation/problem), What (summary of changes), and How (approach/implementation details)? If any are missing, request them — you can't judge the approach without understanding the problem and intent.
|
||||
|
||||
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
|
||||
|
||||
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,195 +0,0 @@
|
||||
---
|
||||
name: setup-repo
|
||||
description: Initialize a worktree-based repo layout for parallel development. Creates a main worktree, a reviews worktree for PR reviews, and N numbered work branches. Handles .env creation, dependency installation, and branchlet config. TRIGGER when user asks to set up the repo from scratch, initialize worktrees, bootstrap their dev environment, "setup repo", "setup worktrees", "initialize dev environment", "set up branches", or when a freshly cloned repo has no sibling worktrees.
|
||||
user-invocable: true
|
||||
args: "No arguments — interactive setup via prompts."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Repository Setup
|
||||
|
||||
This skill sets up a worktree-based development layout from a freshly cloned repo. It creates:
|
||||
- A **main** worktree (the primary checkout)
|
||||
- A **reviews** worktree (for PR reviews)
|
||||
- **N work branches** (branch1..branchN) for parallel development
|
||||
|
||||
## Step 1: Identify the repo
|
||||
|
||||
Determine the repo root and parent directory:
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
REPO_NAME=$(basename "$ROOT")
|
||||
PARENT=$(dirname "$ROOT")
|
||||
```
|
||||
|
||||
Detect if the repo is already inside a worktree layout by counting sibling worktrees (not just checking the directory name, which could be anything):
|
||||
|
||||
```bash
|
||||
# Count worktrees that are siblings (live under $PARENT but aren't $ROOT itself)
|
||||
SIBLING_COUNT=$(git worktree list --porcelain 2>/dev/null | grep "^worktree " | grep -c "$PARENT/" || true)
|
||||
if [ "$SIBLING_COUNT" -gt 1 ]; then
|
||||
echo "INFO: Existing worktree layout detected at $PARENT ($SIBLING_COUNT worktrees)"
|
||||
# Use $ROOT as-is; skip renaming/restructuring
|
||||
else
|
||||
echo "INFO: Fresh clone detected, proceeding with setup"
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 2: Ask the user questions
|
||||
|
||||
Use AskUserQuestion to gather setup preferences:
|
||||
|
||||
1. **How many parallel work branches do you need?** (Options: 4, 8, 16, or custom)
|
||||
- These become `branch1` through `branchN`
|
||||
2. **Which branch should be the base?** (Options: origin/master, origin/dev, or custom)
|
||||
- All work branches and reviews will start from this
|
||||
|
||||
## Step 3: Fetch and set up branches
|
||||
|
||||
```bash
|
||||
cd "$ROOT"
|
||||
git fetch origin
|
||||
|
||||
# Create the reviews branch from base (skip if already exists)
|
||||
if git show-ref --verify --quiet refs/heads/reviews; then
|
||||
echo "INFO: Branch 'reviews' already exists, skipping"
|
||||
else
|
||||
git branch reviews <base-branch>
|
||||
fi
|
||||
|
||||
# Create numbered work branches from base (skip if already exists)
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if git show-ref --verify --quiet "refs/heads/branch$i"; then
|
||||
echo "INFO: Branch 'branch$i' already exists, skipping"
|
||||
else
|
||||
git branch "branch$i" <base-branch>
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 4: Create worktrees
|
||||
|
||||
Create worktrees as siblings to the main checkout:
|
||||
|
||||
```bash
|
||||
if [ -d "$PARENT/reviews" ]; then
|
||||
echo "INFO: Worktree '$PARENT/reviews' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/reviews" reviews
|
||||
fi
|
||||
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if [ -d "$PARENT/branch$i" ]; then
|
||||
echo "INFO: Worktree '$PARENT/branch$i' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/branch$i" "branch$i"
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 5: Set up environment files
|
||||
|
||||
**Do NOT assume .env files exist.** For each worktree (including main if needed):
|
||||
|
||||
1. Check if `.env` exists in the source worktree for each path
|
||||
2. If `.env` exists, copy it
|
||||
3. If only `.env.default` or `.env.example` exists, copy that as `.env`
|
||||
4. If neither exists, warn the user and list which env files are missing
|
||||
|
||||
Env file locations to check (same as the `/worktree` skill — keep these in sync):
|
||||
- `autogpt_platform/.env`
|
||||
- `autogpt_platform/backend/.env`
|
||||
- `autogpt_platform/frontend/.env`
|
||||
|
||||
> **Note:** This env copying logic intentionally mirrors the `/worktree` skill's approach. If you update the path list or fallback logic here, update `/worktree` as well.
|
||||
|
||||
```bash
|
||||
SOURCE="$ROOT"
|
||||
WORKTREES="reviews"
|
||||
for i in $(seq 1 "$COUNT"); do WORKTREES="$WORKTREES branch$i"; done
|
||||
|
||||
FOUND_ANY_ENV=0
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
for envpath in autogpt_platform autogpt_platform/backend autogpt_platform/frontend; do
|
||||
if [ -f "$SOURCE/$envpath/.env" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env" "$TARGET/$envpath/.env"
|
||||
elif [ -f "$SOURCE/$envpath/.env.default" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.default — you may need to edit it"
|
||||
elif [ -f "$SOURCE/$envpath/.env.example" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.example" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.example — you may need to edit it"
|
||||
else
|
||||
echo "WARNING: No .env, .env.default, or .env.example found at $SOURCE/$envpath/"
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
if [ "$FOUND_ANY_ENV" -eq 0 ]; then
|
||||
echo "WARNING: No environment files or templates were found in the source worktree."
|
||||
# Use AskUserQuestion to confirm: "Continue setup without env files?"
|
||||
# If the user declines, stop here and let them set up .env files first.
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 6: Copy branchlet config
|
||||
|
||||
Copy `.branchlet.json` from main to each worktree so branchlet can manage sub-worktrees:
|
||||
|
||||
```bash
|
||||
if [ -f "$ROOT/.branchlet.json" ]; then
|
||||
for wt in $WORKTREES; do
|
||||
cp "$ROOT/.branchlet.json" "$PARENT/$wt/.branchlet.json"
|
||||
done
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 7: Install dependencies
|
||||
|
||||
Install deps in all worktrees. Run these sequentially per worktree:
|
||||
|
||||
```bash
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
echo "=== Installing deps for $wt ==="
|
||||
(cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install) &&
|
||||
(cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate) &&
|
||||
(cd "$TARGET/autogpt_platform/frontend" && pnpm install) &&
|
||||
echo "=== Done: $wt ===" ||
|
||||
echo "=== FAILED: $wt ==="
|
||||
done
|
||||
```
|
||||
|
||||
This is slow. Run in background if possible and notify when complete.
|
||||
|
||||
## Step 8: Verify and report
|
||||
|
||||
After setup, verify and report to the user:
|
||||
|
||||
```bash
|
||||
git worktree list
|
||||
```
|
||||
|
||||
Summarize:
|
||||
- Number of worktrees created
|
||||
- Which env files were copied vs created from defaults vs missing
|
||||
- Any warnings or errors encountered
|
||||
|
||||
## Final directory layout
|
||||
|
||||
```
|
||||
parent/
|
||||
main/ # Primary checkout (already exists)
|
||||
reviews/ # PR review worktree
|
||||
branch1/ # Work branch 1
|
||||
branch2/ # Work branch 2
|
||||
...
|
||||
branchN/ # Work branch N
|
||||
```
|
||||
@@ -1,225 +0,0 @@
|
||||
---
|
||||
name: write-frontend-tests
|
||||
description: "Analyze the current branch diff against dev, plan integration tests for changed frontend pages/components, and write them. TRIGGER when user asks to write frontend tests, add test coverage, or 'write tests for my changes'."
|
||||
user-invocable: true
|
||||
args: "[base branch] — defaults to dev. Optionally pass a specific base branch to diff against."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Write Frontend Tests
|
||||
|
||||
Analyze the current branch's frontend changes, plan integration tests, and write them.
|
||||
|
||||
## References
|
||||
|
||||
Before writing any tests, read the testing rules and conventions:
|
||||
|
||||
- `autogpt_platform/frontend/TESTING.md` — testing strategy, file locations, examples
|
||||
- `autogpt_platform/frontend/src/tests/AGENTS.md` — detailed testing rules, MSW patterns, decision flowchart
|
||||
- `autogpt_platform/frontend/src/tests/integrations/test-utils.tsx` — custom render with providers
|
||||
- `autogpt_platform/frontend/src/tests/integrations/vitest.setup.tsx` — MSW server setup
|
||||
|
||||
## Step 1: Identify changed frontend files
|
||||
|
||||
```bash
|
||||
BASE_BRANCH="${ARGUMENTS:-dev}"
|
||||
cd autogpt_platform/frontend
|
||||
|
||||
# Get changed frontend files (excluding generated, config, and test files)
|
||||
git diff "$BASE_BRANCH"...HEAD --name-only -- src/ \
|
||||
| grep -v '__generated__' \
|
||||
| grep -v '__tests__' \
|
||||
| grep -v '\.test\.' \
|
||||
| grep -v '\.stories\.' \
|
||||
| grep -v '\.spec\.'
|
||||
```
|
||||
|
||||
Also read the diff to understand what changed:
|
||||
|
||||
```bash
|
||||
git diff "$BASE_BRANCH"...HEAD --stat -- src/
|
||||
git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
|
||||
```
|
||||
|
||||
## Step 2: Categorize changes and find test targets
|
||||
|
||||
For each changed file, determine:
|
||||
|
||||
1. **Is it a page?** (`page.tsx`) — these are the primary test targets
|
||||
2. **Is it a hook?** (`use*.ts`) — test via the page/component that uses it; avoid direct `renderHook()` tests unless it is a shared reusable hook with standalone business logic
|
||||
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
|
||||
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
|
||||
|
||||
**Priority order:**
|
||||
|
||||
1. Pages with new/changed data fetching or user interactions
|
||||
2. Components with complex internal logic (modals, forms, wizards)
|
||||
3. Shared hooks with standalone business logic when UI-level coverage is impractical
|
||||
4. Pure helper functions
|
||||
|
||||
Skip: styling-only changes, type-only changes, config changes.
|
||||
|
||||
## Step 3: Check for existing tests
|
||||
|
||||
For each test target, check if tests already exist:
|
||||
|
||||
```bash
|
||||
# For a page at src/app/(platform)/library/page.tsx
|
||||
ls src/app/\(platform\)/library/__tests__/ 2>/dev/null
|
||||
|
||||
# For a component at src/app/(platform)/library/components/AgentCard/AgentCard.tsx
|
||||
ls src/app/\(platform\)/library/components/AgentCard/__tests__/ 2>/dev/null
|
||||
```
|
||||
|
||||
Note which targets have no tests (need new files) vs which have tests that need updating.
|
||||
|
||||
## Step 4: Identify API endpoints used
|
||||
|
||||
For each test target, find which API hooks are used:
|
||||
|
||||
```bash
|
||||
# Find generated API hook imports in the changed files
|
||||
grep -rn 'from.*__generated__/endpoints' src/app/\(platform\)/library/
|
||||
grep -rn 'use[A-Z].*V[12]' src/app/\(platform\)/library/
|
||||
```
|
||||
|
||||
For each API hook found, locate the corresponding MSW handler:
|
||||
|
||||
```bash
|
||||
# If the page uses useGetV2ListLibraryAgents, find its MSW handlers
|
||||
grep -rn 'getGetV2ListLibraryAgents.*Handler' src/app/api/__generated__/endpoints/library/library.msw.ts
|
||||
```
|
||||
|
||||
List every MSW handler you will need (200 for happy path, 4xx for error paths).
|
||||
|
||||
## Step 5: Write the test plan
|
||||
|
||||
Before writing code, output a plan as a numbered list:
|
||||
|
||||
```
|
||||
Test plan for [branch name]:
|
||||
|
||||
1. src/app/(platform)/library/__tests__/main.test.tsx (NEW)
|
||||
- Renders page with agent list (MSW 200)
|
||||
- Shows loading state
|
||||
- Shows error state (MSW 422)
|
||||
- Handles empty agent list
|
||||
|
||||
2. src/app/(platform)/library/__tests__/search.test.tsx (NEW)
|
||||
- Filters agents by search query
|
||||
- Shows no results message
|
||||
- Clears search
|
||||
|
||||
3. src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx (UPDATE)
|
||||
- Add test for new "duplicate" action
|
||||
```
|
||||
|
||||
Present this plan to the user. Wait for confirmation before proceeding. If the user has feedback, adjust the plan.
|
||||
|
||||
## Step 6: Write the tests
|
||||
|
||||
For each test file in the plan, follow these conventions:
|
||||
|
||||
### File structure
|
||||
|
||||
```tsx
|
||||
import { render, screen, waitFor } from "@/tests/integrations/test-utils";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
// Import MSW handlers for endpoints the page uses
|
||||
import {
|
||||
getGetV2ListLibraryAgentsMockHandler200,
|
||||
getGetV2ListLibraryAgentsMockHandler422,
|
||||
} from "@/app/api/__generated__/endpoints/library/library.msw";
|
||||
// Import the component under test
|
||||
import LibraryPage from "../page";
|
||||
|
||||
describe("LibraryPage", () => {
|
||||
test("renders agent list from API", async () => {
|
||||
server.use(getGetV2ListLibraryAgentsMockHandler200());
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
expect(await screen.findByText(/my agents/i)).toBeDefined();
|
||||
});
|
||||
|
||||
test("shows error state on API failure", async () => {
|
||||
server.use(getGetV2ListLibraryAgentsMockHandler422());
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
expect(await screen.findByText(/error/i)).toBeDefined();
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Rules
|
||||
|
||||
- Use `render()` from `@/tests/integrations/test-utils` (NOT from `@testing-library/react` directly)
|
||||
- Use `server.use()` to set up MSW handlers BEFORE rendering
|
||||
- Use `findBy*` (async) for elements that appear after data fetching — NOT `getBy*`
|
||||
- Use `getBy*` only for elements that are immediately present in the DOM
|
||||
- Use `screen` queries — do NOT destructure from `render()`
|
||||
- Use `waitFor` when asserting side effects or state changes after interactions
|
||||
- Import `fireEvent` or `userEvent` from the test-utils for interactions
|
||||
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
|
||||
- Prefer Orval-generated MSW handlers and response builders over hand-built API response objects
|
||||
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
|
||||
- Keep tests focused: one behavior per test
|
||||
- Use descriptive test names that read like sentences
|
||||
|
||||
### Test location
|
||||
|
||||
```
|
||||
# For pages: __tests__/ next to page.tsx
|
||||
src/app/(platform)/library/__tests__/main.test.tsx
|
||||
|
||||
# For complex standalone components: __tests__/ inside component folder
|
||||
src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx
|
||||
|
||||
# For pure helpers: co-located .test.ts
|
||||
src/app/(platform)/library/helpers.test.ts
|
||||
```
|
||||
|
||||
### Custom MSW overrides
|
||||
|
||||
When the auto-generated faker data is not enough, override with specific data:
|
||||
|
||||
```tsx
|
||||
import { http, HttpResponse } from "msw";
|
||||
|
||||
server.use(
|
||||
http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
|
||||
return HttpResponse.json({
|
||||
agents: [{ id: "1", name: "Test Agent", description: "A test agent" }],
|
||||
pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
|
||||
});
|
||||
}),
|
||||
);
|
||||
```
|
||||
|
||||
Use the proxy URL pattern: `http://localhost:3000/api/proxy/api/v{version}/{path}` — this matches the MSW base URL configured in `orval.config.ts`.
|
||||
|
||||
## Step 7: Run and verify
|
||||
|
||||
After writing all tests:
|
||||
|
||||
```bash
|
||||
cd autogpt_platform/frontend
|
||||
pnpm test:unit --reporter=verbose
|
||||
```
|
||||
|
||||
If tests fail:
|
||||
|
||||
1. Read the error output carefully
|
||||
2. Fix the test (not the source code, unless there is a genuine bug)
|
||||
3. Re-run until all pass
|
||||
|
||||
Then run the full checks:
|
||||
|
||||
```bash
|
||||
pnpm format
|
||||
pnpm lint
|
||||
pnpm types
|
||||
```
|
||||
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,12 +1,8 @@
|
||||
### Why / What / How
|
||||
|
||||
<!-- Why: Why does this PR exist? What problem does it solve, or what's broken/missing without it? -->
|
||||
<!-- What: What does this PR change? Summarize the changes at a high level. -->
|
||||
<!-- How: How does it work? Describe the approach, key implementation details, or architecture decisions. -->
|
||||
<!-- Clearly explain the need for these changes: -->
|
||||
|
||||
### Changes 🏗️
|
||||
|
||||
<!-- List the key changes. Keep it higher level than the diff but specific enough to highlight what's new/modified. -->
|
||||
<!-- Concisely describe all of the changes made in this pull request: -->
|
||||
|
||||
### Checklist 📋
|
||||
|
||||
|
||||
78
.github/workflows/classic-autogpt-ci.yml
vendored
78
.github/workflows/classic-autogpt-ci.yml
vendored
@@ -6,19 +6,11 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -27,22 +19,47 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
working-directory: classic/original_autogpt
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
|
||||
steps:
|
||||
- name: Start MinIO service
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -54,23 +71,41 @@ jobs:
|
||||
git config --global user.name "Auto-GPT-Bot"
|
||||
git config --global user.email "github-bot@agpt.co"
|
||||
|
||||
- name: Set up Python 3.12
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
@@ -81,13 +116,12 @@ jobs:
|
||||
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--numprocesses=logical --durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
original_autogpt/tests/unit original_autogpt/tests/integration
|
||||
tests/unit tests/integration
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -101,11 +135,11 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: autogpt-agent
|
||||
flags: autogpt-agent,${{ runner.os }}
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/logs/
|
||||
path: classic/original_autogpt/logs/
|
||||
|
||||
@@ -148,7 +148,7 @@ jobs:
|
||||
--entrypoint poetry ${{ env.IMAGE_NAME }} run \
|
||||
pytest -v --cov=autogpt --cov-branch --cov-report term-missing \
|
||||
--numprocesses=4 --durations=10 \
|
||||
original_autogpt/tests/unit original_autogpt/tests/integration 2>&1 | tee test_output.txt
|
||||
tests/unit tests/integration 2>&1 | tee test_output.txt
|
||||
|
||||
test_failure=${PIPESTATUS[0]}
|
||||
|
||||
|
||||
44
.github/workflows/classic-autogpts-ci.yml
vendored
44
.github/workflows/classic-autogpts-ci.yml
vendored
@@ -10,9 +10,10 @@ on:
|
||||
- '.github/workflows/classic-autogpts-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- '!**/*.md'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
@@ -20,9 +21,10 @@ on:
|
||||
- '.github/workflows/classic-autogpts-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- '!**/*.md'
|
||||
|
||||
defaults:
|
||||
@@ -33,9 +35,13 @@ defaults:
|
||||
jobs:
|
||||
serve-agent-protocol:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [ original_autogpt ]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
env:
|
||||
min-python-version: '3.12'
|
||||
min-python-version: '3.10'
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -49,22 +55,22 @@ jobs:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
working-directory: ./classic/${{ matrix.agent-name }}/
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run smoke tests with direct-benchmark
|
||||
- name: Run regression tests
|
||||
run: |
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
|
||||
poetry run agbenchmark --test=WriteFile
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
AGENT_NAME: ${{ matrix.agent-name }}
|
||||
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
CI: true
|
||||
HELICONE_CACHE_ENABLED: false
|
||||
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
|
||||
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
|
||||
202
.github/workflows/classic-benchmark-ci.yml
vendored
202
.github/workflows/classic-benchmark-ci.yml
vendored
@@ -1,24 +1,18 @@
|
||||
name: Classic - Direct Benchmark CI
|
||||
name: Classic - AGBenchmark CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master, dev, ci-test* ]
|
||||
paths:
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -29,16 +23,23 @@ defaults:
|
||||
shell: bash
|
||||
|
||||
env:
|
||||
min-python-version: '3.12'
|
||||
min-python-version: '3.10'
|
||||
|
||||
jobs:
|
||||
benchmark-tests:
|
||||
runs-on: ubuntu-latest
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
working-directory: classic/benchmark
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -46,88 +47,71 @@ jobs:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install dependencies
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run basic benchmark tests
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
echo "Testing ReadFile challenge with one_shot strategy..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--json
|
||||
|
||||
echo "Testing WriteFile challenge..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests WriteFile \
|
||||
--json
|
||||
poetry run pytest -vv \
|
||||
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
- name: Test category filtering
|
||||
run: |
|
||||
echo "Testing coding category..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--categories coding \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
- name: Upload test results to Codecov
|
||||
if: ${{ !cancelled() }} # Run even if tests fail
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Test multiple strategies
|
||||
run: |
|
||||
echo "Testing multiple strategies..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot,plan_execute \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--parallel 2 \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: agbenchmark,${{ runner.os }}
|
||||
|
||||
# Run regression tests on maintain challenges
|
||||
regression-tests:
|
||||
self-test-with-agent:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [forge]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -140,31 +124,53 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
|
||||
- name: Run regression tests
|
||||
working-directory: classic
|
||||
run: |
|
||||
echo "Running regression tests (previously beaten challenges)..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--maintain \
|
||||
--parallel 4 \
|
||||
--json
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
|
||||
set +e # Ignore non-zero exit codes and continue execution
|
||||
echo "Running the following command: poetry run agbenchmark --maintain --mock"
|
||||
poetry run agbenchmark --maintain --mock
|
||||
EXIT_CODE=$?
|
||||
set -e # Stop ignoring non-zero exit codes
|
||||
# Check if the exit code was 5, and if so, exit with 0 instead
|
||||
if [ $EXIT_CODE -eq 5 ]; then
|
||||
echo "regression_tests.json is empty."
|
||||
fi
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock"
|
||||
poetry run agbenchmark --mock
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=data"
|
||||
poetry run agbenchmark --mock --category=data
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
|
||||
poetry run agbenchmark --mock --category=coding
|
||||
|
||||
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
|
||||
# poetry run agbenchmark --test=WriteFile
|
||||
cd ../benchmark
|
||||
poetry install
|
||||
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
|
||||
export BUILD_SKILL_TREE=true
|
||||
|
||||
# poetry run agbenchmark --mock
|
||||
|
||||
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
|
||||
# if [ ! -z "$CHANGED" ]; then
|
||||
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
|
||||
# echo "$CHANGED"
|
||||
# exit 1
|
||||
# else
|
||||
# echo "No unstaged changes."
|
||||
# fi
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
|
||||
189
.github/workflows/classic-forge-ci.yml
vendored
189
.github/workflows/classic-forge-ci.yml
vendored
@@ -6,15 +6,13 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -23,60 +21,131 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
working-directory: classic/forge
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
|
||||
steps:
|
||||
- name: Start MinIO service
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python 3.12
|
||||
- name: Checkout cassettes
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
env:
|
||||
PR_BASE: ${{ github.event.pull_request.base.ref }}
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
cassette_base_branch="${PR_BASE}"
|
||||
cd tests/vcr_cassettes
|
||||
|
||||
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
|
||||
cassette_base_branch="master"
|
||||
fi
|
||||
|
||||
if git ls-remote --exit-code --heads origin $cassette_branch ; then
|
||||
git fetch origin $cassette_branch
|
||||
git fetch origin $cassette_base_branch
|
||||
|
||||
git checkout $cassette_branch
|
||||
|
||||
# Pick non-conflicting cassette updates from the base branch
|
||||
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
|
||||
echo "Using cassettes from mirror branch '$cassette_branch'," \
|
||||
"synced to upstream branch '$cassette_base_branch'."
|
||||
else
|
||||
git checkout -b $cassette_branch
|
||||
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
|
||||
"Using cassettes from '$cassette_base_branch'."
|
||||
fi
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Install Playwright browsers
|
||||
run: poetry run playwright install chromium
|
||||
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
poetry run pytest -vv \
|
||||
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
forge/forge forge/tests
|
||||
forge
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
# API keys - tests that need these will skip if not available
|
||||
# Secrets are not available to fork PRs (GitHub security feature)
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -90,11 +159,85 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: forge
|
||||
flags: forge,${{ runner.os }}
|
||||
|
||||
- id: setup_git_auth
|
||||
name: Set up git token authentication
|
||||
# Cassettes may be pushed even when tests fail
|
||||
if: success() || failure()
|
||||
run: |
|
||||
config_key="http.${{ github.server_url }}/.extraheader"
|
||||
if [ "${{ runner.os }}" = 'macOS' ]; then
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
|
||||
else
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
|
||||
fi
|
||||
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
echo "config_key=$config_key" >> $GITHUB_OUTPUT
|
||||
|
||||
- id: push_cassettes
|
||||
name: Push updated cassettes
|
||||
# For pull requests, push updated cassettes even when tests fail
|
||||
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
|
||||
env:
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
|
||||
is_pull_request=true
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
else
|
||||
cassette_branch="${{ github.ref_name }}"
|
||||
fi
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
# Commit & push changes to cassettes if any
|
||||
if ! git diff --quiet; then
|
||||
git add .
|
||||
git commit -m "Auto-update cassettes"
|
||||
git push origin HEAD:$cassette_branch
|
||||
if [ ! $is_pull_request ]; then
|
||||
cd ../..
|
||||
git add tests/vcr_cassettes
|
||||
git commit -m "Update cassette submodule"
|
||||
git push origin HEAD:$cassette_branch
|
||||
fi
|
||||
echo "updated=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "updated=false" >> $GITHUB_OUTPUT
|
||||
echo "No cassette changes to commit"
|
||||
fi
|
||||
|
||||
- name: Post Set up git token auth
|
||||
if: steps.setup_git_auth.outcome == 'success'
|
||||
run: |
|
||||
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
|
||||
- name: Apply "behaviour change" label and comment on PR
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
run: |
|
||||
PR_NUMBER="${{ github.event.pull_request.number }}"
|
||||
TOKEN="${{ secrets.PAT_REVIEW }}"
|
||||
REPO="${{ github.repository }}"
|
||||
|
||||
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
|
||||
echo "Adding label and comment..."
|
||||
echo $TOKEN | gh auth login --with-token
|
||||
gh issue edit $PR_NUMBER --add-label "behaviour change"
|
||||
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
|
||||
fi
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/logs/
|
||||
path: classic/forge/logs/
|
||||
|
||||
60
.github/workflows/classic-frontend-ci.yml
vendored
Normal file
60
.github/workflows/classic-frontend-ci.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
name: Classic - Frontend CI/CD
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- dev
|
||||
- 'ci-test*' # This will match any branch that starts with "ci-test"
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Flutter
|
||||
uses: subosito/flutter-action@v2
|
||||
with:
|
||||
flutter-version: '3.13.2'
|
||||
|
||||
- name: Build Flutter to Web
|
||||
run: |
|
||||
cd classic/frontend
|
||||
flutter build web --base-href /app/
|
||||
|
||||
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
|
||||
# if: github.event_name == 'push'
|
||||
# run: |
|
||||
# git config --local user.email "action@github.com"
|
||||
# git config --local user.name "GitHub Action"
|
||||
# git add classic/frontend/build/web
|
||||
# git checkout -B ${{ env.BUILD_BRANCH }}
|
||||
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
|
||||
# git push -f origin ${{ env.BUILD_BRANCH }}
|
||||
|
||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||
if: github.event_name == 'push'
|
||||
uses: peter-evans/create-pull-request@v8
|
||||
with:
|
||||
add-paths: classic/frontend/build/web
|
||||
base: ${{ github.ref_name }}
|
||||
branch: ${{ env.BUILD_BRANCH }}
|
||||
delete-branch: true
|
||||
title: "Update frontend build in `${{ github.ref_name }}`"
|
||||
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
|
||||
commit-message: "Update frontend build based on commit ${{ github.sha }}"
|
||||
67
.github/workflows/classic-python-checks.yml
vendored
67
.github/workflows/classic-python-checks.yml
vendored
@@ -7,9 +7,7 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- 'classic/benchmark/**'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
pull_request:
|
||||
@@ -18,9 +16,7 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- 'classic/benchmark/**'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
|
||||
@@ -31,13 +27,44 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
get-changed-parts:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- id: changes-in
|
||||
name: Determine affected subprojects
|
||||
uses: dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
original_autogpt:
|
||||
- classic/original_autogpt/autogpt/**
|
||||
- classic/original_autogpt/tests/**
|
||||
- classic/original_autogpt/poetry.lock
|
||||
forge:
|
||||
- classic/forge/forge/**
|
||||
- classic/forge/tests/**
|
||||
- classic/forge/poetry.lock
|
||||
benchmark:
|
||||
- classic/benchmark/agbenchmark/**
|
||||
- classic/benchmark/tests/**
|
||||
- classic/benchmark/poetry.lock
|
||||
outputs:
|
||||
changed-parts: ${{ steps.changes-in.outputs.changes }}
|
||||
|
||||
lint:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.12"
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -54,31 +81,42 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
|
||||
# Lint
|
||||
|
||||
- name: Lint (isort)
|
||||
run: poetry run isort --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Black)
|
||||
if: success() || failure()
|
||||
run: poetry run black --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Flake8)
|
||||
if: success() || failure()
|
||||
run: poetry run flake8 .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
types:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.12"
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -95,16 +133,19 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
|
||||
# Typecheck
|
||||
|
||||
- name: Typecheck
|
||||
if: success() || failure()
|
||||
run: poetry run pyright
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
20
.github/workflows/platform-backend-ci.yml
vendored
20
.github/workflows/platform-backend-ci.yml
vendored
@@ -269,14 +269,12 @@ jobs:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- name: Run pytest with coverage
|
||||
- name: Run pytest
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG \
|
||||
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
else
|
||||
poetry run pytest -s -vv \
|
||||
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
|
||||
poetry run pytest -s -vv
|
||||
fi
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
@@ -289,13 +287,11 @@ jobs:
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-backend
|
||||
files: ./autogpt_platform/backend/coverage.xml
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
|
||||
env:
|
||||
CI: true
|
||||
|
||||
8
.github/workflows/platform-frontend-ci.yml
vendored
8
.github/workflows/platform-frontend-ci.yml
vendored
@@ -148,11 +148,3 @@ jobs:
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: pnpm test:unit
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-frontend
|
||||
files: ./autogpt_platform/frontend/coverage/cobertura-coverage.xml
|
||||
|
||||
38
.github/workflows/platform-fullstack-ci.yml
vendored
38
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -160,7 +160,6 @@ jobs:
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
echo "SCHEDULER_STARTUP_EMBEDDING_BACKFILL=false" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
@@ -180,30 +179,21 @@ jobs:
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
export NEXT_PUBLIC_SOURCEMAPS NEXT_PUBLIC_PW_TEST
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Ensure NEXT_PUBLIC_SOURCEMAPS is in resolved compose
|
||||
# (docker compose config on some versions drops this arg)
|
||||
if ! grep -q "NEXT_PUBLIC_SOURCEMAPS" docker-compose.resolved.yml; then
|
||||
echo "Injecting NEXT_PUBLIC_SOURCEMAPS into resolved compose (docker compose config dropped it)"
|
||||
sed -i '/NEXT_PUBLIC_PW_TEST/a\ NEXT_PUBLIC_SOURCEMAPS: "true"' docker-compose.resolved.yml
|
||||
fi
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}-sourcemaps" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
NEXT_PUBLIC_SOURCEMAPS: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
@@ -289,38 +279,16 @@ jobs:
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Cache Playwright browsers
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
key: playwright-${{ runner.os }}-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
playwright-${{ runner.os }}-
|
||||
|
||||
- name: Copy source maps from Docker for E2E coverage
|
||||
run: |
|
||||
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
|
||||
docker cp "$FRONTEND_CONTAINER":/app/.next/static .next-static-coverage
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright E2E suite
|
||||
run: pnpm test:e2e:no-build
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload E2E coverage to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-frontend-e2e
|
||||
files: ./autogpt_platform/frontend/coverage/e2e/cobertura-coverage.xml
|
||||
disable_search: true
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
|
||||
13
.gitignore
vendored
13
.gitignore
vendored
@@ -3,7 +3,6 @@
|
||||
classic/original_autogpt/keys.py
|
||||
classic/original_autogpt/*.json
|
||||
auto_gpt_workspace/*
|
||||
.autogpt/
|
||||
*.mpeg
|
||||
.env
|
||||
# Root .env files
|
||||
@@ -17,7 +16,6 @@ log-ingestion.txt
|
||||
/logs
|
||||
*.log
|
||||
*.mp3
|
||||
!autogpt_platform/frontend/public/notification.mp3
|
||||
mem.sqlite3
|
||||
venvAutoGPT
|
||||
|
||||
@@ -161,10 +159,6 @@ CURRENT_BULLETIN.md
|
||||
|
||||
# AgBenchmark
|
||||
classic/benchmark/agbenchmark/reports/
|
||||
classic/reports/
|
||||
classic/direct_benchmark/reports/
|
||||
classic/.benchmark_workspaces/
|
||||
classic/direct_benchmark/.benchmark_workspaces/
|
||||
|
||||
# Nodejs
|
||||
package-lock.json
|
||||
@@ -183,16 +177,9 @@ autogpt_platform/backend/settings.py
|
||||
|
||||
*.ign.*
|
||||
.test-contents
|
||||
**/.claude/settings.local.json
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
/autogpt_platform/backend/poetry.toml
|
||||
|
||||
# Test database
|
||||
test.db
|
||||
.next
|
||||
# Implementation plans (generated by AI agents)
|
||||
plans/
|
||||
.claude/worktrees/
|
||||
test-results/
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
title = "AutoGPT Gitleaks Config"
|
||||
|
||||
[extend]
|
||||
useDefault = true
|
||||
|
||||
[allowlist]
|
||||
description = "Global allowlist"
|
||||
paths = [
|
||||
# Template/example env files (no real secrets)
|
||||
'''\.env\.(default|example|template)$''',
|
||||
# Lock files
|
||||
'''pnpm-lock\.yaml$''',
|
||||
'''poetry\.lock$''',
|
||||
# Secrets baseline
|
||||
'''\.secrets\.baseline$''',
|
||||
# Build artifacts and caches (should not be committed)
|
||||
'''__pycache__/''',
|
||||
'''classic/frontend/build/''',
|
||||
# Docker dev setup (local dev JWTs/keys only)
|
||||
'''autogpt_platform/db/docker/''',
|
||||
# Load test configs (dev JWTs)
|
||||
'''load-tests/configs/''',
|
||||
# Test files with fake/fixture keys (_test.py, test_*.py, conftest.py)
|
||||
'''(_test|test_.*|conftest)\.py$''',
|
||||
# Documentation (only contains placeholder keys in curl/API examples)
|
||||
'''docs/.*\.md$''',
|
||||
# Firebase config (public API keys by design)
|
||||
'''google-services\.json$''',
|
||||
'''classic/frontend/(lib|web)/''',
|
||||
]
|
||||
# CI test-only encryption key (marked DO NOT USE IN PRODUCTION)
|
||||
regexes = [
|
||||
'''dvziYgz0KSK8FENhju0ZYi8''',
|
||||
# LLM model name enum values falsely flagged as API keys
|
||||
'''Llama-\d.*Instruct''',
|
||||
]
|
||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "classic/forge/tests/vcr_cassettes"]
|
||||
path = classic/forge/tests/vcr_cassettes
|
||||
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
|
||||
@@ -23,15 +23,9 @@ repos:
|
||||
- id: detect-secrets
|
||||
name: Detect secrets
|
||||
description: Detects high entropy strings that are likely to be passwords.
|
||||
args: ["--baseline", ".secrets.baseline"]
|
||||
files: ^autogpt_platform/
|
||||
exclude: (pnpm-lock\.yaml|\.env\.(default|example|template))$
|
||||
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.24.3
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
name: Detect secrets (gitleaks)
|
||||
exclude: pnpm-lock\.yaml$
|
||||
stages: [pre-push]
|
||||
|
||||
- repo: local
|
||||
# For proper type checking, all dependencies need to be up-to-date.
|
||||
@@ -90,16 +84,51 @@ repos:
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic
|
||||
alias: poetry-install-classic
|
||||
name: Check & Install dependencies - Classic - AutoGPT
|
||||
alias: poetry-install-classic-autogpt
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/poetry\.lock$" || exit 0;
|
||||
poetry -C classic install
|
||||
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/original_autogpt install
|
||||
'
|
||||
# include forge source (since it's a path dependency)
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Forge
|
||||
alias: poetry-install-classic-forge
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/forge install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Benchmark
|
||||
alias: poetry-install-classic-benchmark
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/benchmark install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
@@ -194,10 +223,26 @@ repos:
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic
|
||||
alias: isort-classic
|
||||
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
name: Lint (isort) - Classic - AutoGPT
|
||||
alias: isort-classic-autogpt
|
||||
entry: poetry -P classic/original_autogpt run isort -p autogpt
|
||||
files: ^classic/original_autogpt/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Forge
|
||||
alias: isort-classic-forge
|
||||
entry: poetry -P classic/forge run isort -p forge
|
||||
files: ^classic/forge/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Benchmark
|
||||
alias: isort-classic-benchmark
|
||||
entry: poetry -P classic/benchmark run isort -p agbenchmark
|
||||
files: ^classic/benchmark/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
@@ -211,13 +256,26 @@ repos:
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.0.0
|
||||
# Use consolidated flake8 config at classic/.flake8
|
||||
# To have flake8 load the config of the individual subprojects, we have to call
|
||||
# them separately.
|
||||
hooks:
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic
|
||||
alias: flake8-classic
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
args: [--config=classic/.flake8]
|
||||
name: Lint (Flake8) - Classic - AutoGPT
|
||||
alias: flake8-classic-autogpt
|
||||
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
|
||||
args: [--config=classic/original_autogpt/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Forge
|
||||
alias: flake8-classic-forge
|
||||
files: ^classic/forge/(forge|tests)/
|
||||
args: [--config=classic/forge/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Benchmark
|
||||
alias: flake8-classic-benchmark
|
||||
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
|
||||
args: [--config=classic/benchmark/.flake8]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
@@ -253,10 +311,29 @@ repos:
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic
|
||||
alias: pyright-classic
|
||||
entry: poetry -C classic run pyright
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
|
||||
name: Typecheck - Classic - AutoGPT
|
||||
alias: pyright-classic-autogpt
|
||||
entry: poetry -C classic/original_autogpt run pyright
|
||||
# include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Forge
|
||||
alias: pyright-classic-forge
|
||||
entry: poetry -C classic/forge run pyright
|
||||
files: ^classic/forge/(forge/|poetry\.lock$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Benchmark
|
||||
alias: pyright-classic-benchmark
|
||||
entry: poetry -C classic/benchmark run pyright
|
||||
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
@@ -283,9 +360,26 @@ repos:
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic (excl. slow tests)
|
||||
# alias: pytest-classic
|
||||
# entry: bash -c 'cd classic && poetry run pytest -m "not slow"'
|
||||
# files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
# name: Run tests - Classic - AutoGPT (excl. slow tests)
|
||||
# alias: pytest-classic-autogpt
|
||||
# entry: bash -c 'cd classic/original_autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
|
||||
# # include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
# files: ^(classic/original_autogpt/((autogpt|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - Forge (excl. slow tests)
|
||||
# alias: pytest-classic-forge
|
||||
# entry: bash -c 'cd classic/forge && poetry run pytest --cov=forge -m "not slow"'
|
||||
# files: ^classic/forge/(forge/|tests/|poetry\.lock$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - Benchmark
|
||||
# alias: pytest-classic-benchmark
|
||||
# entry: bash -c 'cd classic/benchmark && poetry run pytest --cov=benchmark'
|
||||
# files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
@@ -1,471 +0,0 @@
|
||||
{
|
||||
"version": "1.5.0",
|
||||
"plugins_used": [
|
||||
{
|
||||
"name": "ArtifactoryDetector"
|
||||
},
|
||||
{
|
||||
"name": "AWSKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "AzureStorageKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "Base64HighEntropyString",
|
||||
"limit": 4.5
|
||||
},
|
||||
{
|
||||
"name": "BasicAuthDetector"
|
||||
},
|
||||
{
|
||||
"name": "CloudantDetector"
|
||||
},
|
||||
{
|
||||
"name": "DiscordBotTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "GitHubTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "GitLabTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "HexHighEntropyString",
|
||||
"limit": 3.0
|
||||
},
|
||||
{
|
||||
"name": "IbmCloudIamDetector"
|
||||
},
|
||||
{
|
||||
"name": "IbmCosHmacDetector"
|
||||
},
|
||||
{
|
||||
"name": "IPPublicDetector"
|
||||
},
|
||||
{
|
||||
"name": "JwtTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "KeywordDetector",
|
||||
"keyword_exclude": ""
|
||||
},
|
||||
{
|
||||
"name": "MailchimpDetector"
|
||||
},
|
||||
{
|
||||
"name": "NpmDetector"
|
||||
},
|
||||
{
|
||||
"name": "OpenAIDetector"
|
||||
},
|
||||
{
|
||||
"name": "PrivateKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "PypiTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "SendGridDetector"
|
||||
},
|
||||
{
|
||||
"name": "SlackDetector"
|
||||
},
|
||||
{
|
||||
"name": "SoftlayerDetector"
|
||||
},
|
||||
{
|
||||
"name": "SquareOAuthDetector"
|
||||
},
|
||||
{
|
||||
"name": "StripeDetector"
|
||||
},
|
||||
{
|
||||
"name": "TelegramBotTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "TwilioKeyDetector"
|
||||
}
|
||||
],
|
||||
"filters_used": [
|
||||
{
|
||||
"path": "detect_secrets.filters.allowlist.is_line_allowlisted"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.common.is_baseline_file",
|
||||
"filename": ".secrets.baseline"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
|
||||
"min_level": 2
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_indirect_reference"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_likely_id_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_lock_file"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_potential_uuid"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_sequential_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_swagger_file"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_templated_secret"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.regex.should_exclude_file",
|
||||
"pattern": [
|
||||
"\\.env$",
|
||||
"pnpm-lock\\.yaml$",
|
||||
"\\.env\\.(default|example|template)$",
|
||||
"__pycache__",
|
||||
"_test\\.py$",
|
||||
"test_.*\\.py$",
|
||||
"conftest\\.py$",
|
||||
"poetry\\.lock$",
|
||||
"node_modules"
|
||||
]
|
||||
}
|
||||
],
|
||||
"results": {
|
||||
"autogpt_platform/backend/backend/api/external/v1/integrations.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/api/external/v1/integrations.py",
|
||||
"hashed_secret": "665b1e3851eefefa3fb878654292f16597d25155",
|
||||
"is_verified": false,
|
||||
"line_number": 289
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/airtable/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/airtable/_config.py",
|
||||
"hashed_secret": "57e168b03afb7c1ee3cdc4ee3db2fe1cc6e0df26",
|
||||
"is_verified": false,
|
||||
"line_number": 29
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/dataforseo/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/dataforseo/_config.py",
|
||||
"hashed_secret": "32ce93887331fa5d192f2876ea15ec000c7d58b8",
|
||||
"is_verified": false,
|
||||
"line_number": 12
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/checks.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/checks.py",
|
||||
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
|
||||
"is_verified": false,
|
||||
"line_number": 108
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/ci.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/ci.py",
|
||||
"hashed_secret": "90bd1b48e958257948487b90bee080ba5ed00caa",
|
||||
"is_verified": false,
|
||||
"line_number": 123
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "f96896dafced7387dcd22343b8ea29d3d2c65663",
|
||||
"is_verified": false,
|
||||
"line_number": 42
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "b80a94d5e70bedf4f5f89d2f5a5255cc9492d12e",
|
||||
"is_verified": false,
|
||||
"line_number": 193
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "75b17e517fe1b3136394f6bec80c4f892da75e42",
|
||||
"is_verified": false,
|
||||
"line_number": 344
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "b0bfb5e4e2394e7f8906e5ed1dffd88b2bc89dd5",
|
||||
"is_verified": false,
|
||||
"line_number": 534
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/statuses.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/statuses.py",
|
||||
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
|
||||
"is_verified": false,
|
||||
"line_number": 85
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/google/docs.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/google/docs.py",
|
||||
"hashed_secret": "c95da0c6696342c867ef0c8258d2f74d20fd94d4",
|
||||
"is_verified": false,
|
||||
"line_number": 203
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/google/sheets.py": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/google/sheets.py",
|
||||
"hashed_secret": "bd5a04fa3667e693edc13239b6d310c5c7a8564b",
|
||||
"is_verified": false,
|
||||
"line_number": 57
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/linear/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/linear/_config.py",
|
||||
"hashed_secret": "b37f020f42d6d613b6ce30103e4d408c4499b3bb",
|
||||
"is_verified": false,
|
||||
"line_number": 53
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/medium.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/medium.py",
|
||||
"hashed_secret": "ff998abc1ce6d8f01a675fa197368e44c8916e9c",
|
||||
"is_verified": false,
|
||||
"line_number": 131
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/replicate/replicate_block.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
|
||||
"hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
|
||||
"is_verified": false,
|
||||
"line_number": 55
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/slant3d/webhook.py",
|
||||
"hashed_secret": "36263c76947443b2f6e6b78153967ac4a7da99f9",
|
||||
"is_verified": false,
|
||||
"line_number": 100
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/talking_head.py": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/talking_head.py",
|
||||
"hashed_secret": "44ce2d66222529eea4a32932823466fc0601c799",
|
||||
"is_verified": false,
|
||||
"line_number": 113
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/wordpress/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/wordpress/_config.py",
|
||||
"hashed_secret": "e62679512436161b78e8a8d68c8829c2a1031ccb",
|
||||
"is_verified": false,
|
||||
"line_number": 17
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/util/cache.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/util/cache.py",
|
||||
"hashed_secret": "37f0c918c3fa47ca4a70e42037f9f123fdfbc75b",
|
||||
"is_verified": false,
|
||||
"line_number": 449
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts",
|
||||
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
|
||||
"is_verified": false,
|
||||
"line_number": 6
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json",
|
||||
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json",
|
||||
"hashed_secret": "5a6d1c612954979ea99ee33dbb2d231b00f6ac0a",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 6
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
|
||||
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
|
||||
"is_verified": false,
|
||||
"line_number": 8
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
|
||||
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
|
||||
"is_verified": false,
|
||||
"line_number": 7
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 192
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
|
||||
"hashed_secret": "86275db852204937bbdbdebe5fabe8536e030ab6",
|
||||
"is_verified": false,
|
||||
"line_number": 193
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
|
||||
"hashed_secret": "47acd2028cf81b5da88ddeedb2aea4eca4b71fbd",
|
||||
"is_verified": false,
|
||||
"line_number": 102
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
|
||||
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
|
||||
"is_verified": false,
|
||||
"line_number": 103
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "9c486c92f1a7420e1045c7ad963fbb7ba3621025",
|
||||
"is_verified": false,
|
||||
"line_number": 73
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "9277508c7a6effc8fb59163efbfada189e35425c",
|
||||
"is_verified": false,
|
||||
"line_number": 75
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "8dc7e2cb1d0935897d541bf5facab389b8a50340",
|
||||
"is_verified": false,
|
||||
"line_number": 77
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "79a26ad48775944299be6aaf9fb1d5302c1ed75b",
|
||||
"is_verified": false,
|
||||
"line_number": 79
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "a3b62b44500a1612e48d4cab8294df81561b3b1a",
|
||||
"is_verified": false,
|
||||
"line_number": 81
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "a58979bd0b21ef4f50417d001008e60dd7a85c64",
|
||||
"is_verified": false,
|
||||
"line_number": 83
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "6cb6e075f8e8c7c850f9d128d6608e5dbe209a79",
|
||||
"is_verified": false,
|
||||
"line_number": 85
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/lib/constants.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/lib/constants.ts",
|
||||
"hashed_secret": "27b924db06a28cc755fb07c54f0fddc30659fe4d",
|
||||
"is_verified": false,
|
||||
"line_number": 13
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/tests/credentials/index.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/tests/credentials/index.ts",
|
||||
"hashed_secret": "c18006fc138809314751cd1991f1e0b820fabd37",
|
||||
"is_verified": false,
|
||||
"line_number": 4
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_at": "2026-04-09T14:20:23Z"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
# AutoGPT Platform Contribution Guide
|
||||
|
||||
This guide provides context for coding agents when updating the **autogpt_platform** folder.
|
||||
This guide provides context for Codex when updating the **autogpt_platform** folder.
|
||||
|
||||
## Directory overview
|
||||
|
||||
@@ -30,7 +30,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Integration tests (Vitest + RTL + MSW) are the default (~90%, page-level). Playwright for E2E critical flows. Storybook for design system components. See `autogpt_platform/frontend/TESTING.md`
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
@@ -47,9 +47,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
## Testing
|
||||
|
||||
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
||||
- Frontend integration tests: `pnpm test:unit` (Vitest + RTL + MSW, primary testing approach).
|
||||
- Frontend E2E tests: `pnpm test` or `pnpm test-ui` for Playwright tests.
|
||||
- See `autogpt_platform/frontend/TESTING.md` for the full testing strategy.
|
||||
- Frontend: `pnpm test` or `pnpm test-ui` for Playwright tests. See `docs/content/platform/contributing/tests.md` for tips.
|
||||
|
||||
Always run the relevant linters and tests before committing.
|
||||
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
||||
|
||||
@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
|
||||
|
||||
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
|
||||
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
|
||||
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
|
||||
|
||||
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
|
||||
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
|
||||
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.
|
||||
|
||||
|
||||
310
WORKFLOW.md
310
WORKFLOW.md
@@ -1,310 +0,0 @@
|
||||
---
|
||||
hooks:
|
||||
after_create: |
|
||||
if command -v mise >/dev/null 2>&1; then
|
||||
if [ -f mise.toml ]; then
|
||||
mise trust
|
||||
mise exec -- mix deps.get
|
||||
elif [ -f elixir/mise.toml ]; then
|
||||
cd elixir && mise trust && mise exec -- mix deps.get
|
||||
fi
|
||||
fi
|
||||
before_remove: |
|
||||
if [ -f elixir/mix.exs ]; then
|
||||
cd elixir && mise exec -- mix workspace.before_remove
|
||||
fi
|
||||
agent:
|
||||
default_effort: medium
|
||||
max_turns: 20
|
||||
---
|
||||
|
||||
|
||||
You are working on a Linear ticket `{{ issue.identifier }}`
|
||||
|
||||
{% if attempt %}
|
||||
Continuation context:
|
||||
|
||||
- This is retry attempt #{{ attempt }} because the ticket is still in an active state.
|
||||
- Resume from the current workspace state instead of restarting from scratch.
|
||||
- Do not repeat already-completed investigation or validation unless needed for new code changes.
|
||||
- Do not end the turn while the issue remains in an active state unless you are blocked by missing required permissions/secrets.
|
||||
{% endif %}
|
||||
|
||||
Issue context:
|
||||
Identifier: {{ issue.identifier }}
|
||||
Title: {{ issue.title }}
|
||||
Current status: {{ issue.state }}
|
||||
Labels: {{ issue.labels }}
|
||||
URL: {{ issue.url }}
|
||||
|
||||
Description:
|
||||
{% if issue.description %}
|
||||
{{ issue.description }}
|
||||
{% else %}
|
||||
No description provided.
|
||||
{% endif %}
|
||||
|
||||
Instructions:
|
||||
|
||||
1. This is an unattended orchestration session. Never ask a human to perform follow-up actions.
|
||||
2. Only stop early for a true blocker (missing required auth/permissions/secrets). If blocked, record it in the workpad and move the issue according to workflow.
|
||||
3. Final message must report completed actions and blockers only. Do not include "next steps for user".
|
||||
|
||||
Work only in the provided repository copy. Do not touch any other path.
|
||||
|
||||
## Prerequisite: Linear MCP or `linear_graphql` tool is available
|
||||
|
||||
The agent should be able to talk to Linear, either via a configured Linear MCP server or injected `linear_graphql` tool. If none are present, stop and ask the user to configure Linear.
|
||||
|
||||
## Default posture
|
||||
|
||||
- Start by determining the ticket's current status, then follow the matching flow for that status.
|
||||
- Start every task by opening the tracking workpad comment and bringing it up to date before doing new implementation work.
|
||||
- Spend extra effort up front on planning and verification design before implementation.
|
||||
- Reproduce first: always confirm the current behavior/issue signal before changing code so the fix target is explicit.
|
||||
- Keep ticket metadata current (state, checklist, acceptance criteria, links).
|
||||
- Treat a single persistent Linear comment as the source of truth for progress.
|
||||
- Use that single workpad comment for all progress and handoff notes; do not post separate "done"/summary comments.
|
||||
- Treat any ticket-authored `Validation`, `Test Plan`, or `Testing` section as non-negotiable acceptance input: mirror it in the workpad and execute it before considering the work complete.
|
||||
- When meaningful out-of-scope improvements are discovered during execution,
|
||||
file a separate Linear issue instead of expanding scope. The follow-up issue
|
||||
must include a clear title, description, and acceptance criteria, be placed in
|
||||
`Backlog`, be assigned to the same project as the current issue, link the
|
||||
current issue as `related`, and use `blockedBy` when the follow-up depends on
|
||||
the current issue.
|
||||
- Move status only when the matching quality bar is met.
|
||||
- Operate autonomously end-to-end unless blocked by missing requirements, secrets, or permissions.
|
||||
- Use the blocked-access escape hatch only for true external blockers (missing required tools/auth) after exhausting documented fallbacks.
|
||||
|
||||
## Related skills
|
||||
|
||||
- `linear`: interact with Linear.
|
||||
- `commit`: produce clean, logical commits during implementation.
|
||||
- `push`: keep remote branch current and publish updates.
|
||||
- `pull`: keep branch updated with latest `origin/main` before handoff.
|
||||
- `land`: when ticket reaches `Merging`, explicitly open and follow `.codex/skills/land/SKILL.md`, which includes the `land` loop.
|
||||
|
||||
## Status map
|
||||
|
||||
- `Backlog` -> out of scope for this workflow; do not modify.
|
||||
- `Todo` -> queued; immediately transition to `In Progress` before active work.
|
||||
- Special case: if a PR is already attached, treat as feedback/rework loop (run full PR feedback sweep, address or explicitly push back, revalidate, return to `Human Review`).
|
||||
- `In Progress` -> implementation actively underway.
|
||||
- `Human Review` -> PR is attached and validated; waiting on human approval.
|
||||
- `Merging` -> approved by human; execute the `land` skill flow (do not call `gh pr merge` directly).
|
||||
- `Rework` -> reviewer requested changes; planning + implementation required.
|
||||
- `Done` -> terminal state; no further action required.
|
||||
|
||||
## Step 0: Determine current ticket state and route
|
||||
|
||||
1. Fetch the issue by explicit ticket ID.
|
||||
2. Read the current state.
|
||||
3. Route to the matching flow:
|
||||
- `Backlog` -> do not modify issue content/state; stop and wait for human to move it to `Todo`.
|
||||
- `Todo` -> immediately move to `In Progress`, then ensure bootstrap workpad comment exists (create if missing), then start execution flow.
|
||||
- If PR is already attached, start by reviewing all open PR comments and deciding required changes vs explicit pushback responses.
|
||||
- `In Progress` -> continue execution flow from current scratchpad comment.
|
||||
- `Human Review` -> wait and poll for decision/review updates.
|
||||
- `Merging` -> on entry, open and follow `.codex/skills/land/SKILL.md`; do not call `gh pr merge` directly.
|
||||
- `Rework` -> run rework flow.
|
||||
- `Done` -> do nothing and shut down.
|
||||
4. Check whether a PR already exists for the current branch and whether it is closed.
|
||||
- If a branch PR exists and is `CLOSED` or `MERGED`, treat prior branch work as non-reusable for this run.
|
||||
- Create a fresh branch from `origin/main` and restart execution flow as a new attempt.
|
||||
5. For `Todo` tickets, do startup sequencing in this exact order:
|
||||
- `update_issue(..., state: "In Progress")`
|
||||
- find/create `## Codex Workpad` bootstrap comment
|
||||
- only then begin analysis/planning/implementation work.
|
||||
6. Add a short comment if state and issue content are inconsistent, then proceed with the safest flow.
|
||||
|
||||
## Step 1: Start/continue execution (Todo or In Progress)
|
||||
|
||||
1. Find or create a single persistent scratchpad comment for the issue:
|
||||
- Search existing comments for a marker header: `## Codex Workpad`.
|
||||
- Ignore resolved comments while searching; only active/unresolved comments are eligible to be reused as the live workpad.
|
||||
- If found, reuse that comment; do not create a new workpad comment.
|
||||
- If not found, create one workpad comment and use it for all updates.
|
||||
- Persist the workpad comment ID and only write progress updates to that ID.
|
||||
2. If arriving from `Todo`, do not delay on additional status transitions: the issue should already be `In Progress` before this step begins.
|
||||
3. Immediately reconcile the workpad before new edits:
|
||||
- Check off items that are already done.
|
||||
- Expand/fix the plan so it is comprehensive for current scope.
|
||||
- Ensure `Acceptance Criteria` and `Validation` are current and still make sense for the task.
|
||||
4. Start work by writing/updating a hierarchical plan in the workpad comment.
|
||||
5. Ensure the workpad includes a compact environment stamp at the top as a code fence line:
|
||||
- Format: `<host>:<abs-workdir>@<short-sha>`
|
||||
- Example: `devbox-01:/home/dev-user/code/symphony-workspaces/MT-32@7bdde33bc`
|
||||
- Do not include metadata already inferable from Linear issue fields (`issue ID`, `status`, `branch`, `PR link`).
|
||||
6. Add explicit acceptance criteria and TODOs in checklist form in the same comment.
|
||||
- If changes are user-facing, include a UI walkthrough acceptance criterion that describes the end-to-end user path to validate.
|
||||
- If changes touch app files or app behavior, add explicit app-specific flow checks to `Acceptance Criteria` in the workpad (for example: launch path, changed interaction path, and expected result path).
|
||||
- If the ticket description/comment context includes `Validation`, `Test Plan`, or `Testing` sections, copy those requirements into the workpad `Acceptance Criteria` and `Validation` sections as required checkboxes (no optional downgrade).
|
||||
7. Run a principal-style self-review of the plan and refine it in the comment.
|
||||
8. Before implementing, capture a concrete reproduction signal and record it in the workpad `Notes` section (command/output, screenshot, or deterministic UI behavior).
|
||||
9. Run the `pull` skill to sync with latest `origin/main` before any code edits, then record the pull/sync result in the workpad `Notes`.
|
||||
- Include a `pull skill evidence` note with:
|
||||
- merge source(s),
|
||||
- result (`clean` or `conflicts resolved`),
|
||||
- resulting `HEAD` short SHA.
|
||||
10. Compact context and proceed to execution.
|
||||
|
||||
## PR feedback sweep protocol (required)
|
||||
|
||||
When a ticket has an attached PR, run this protocol before moving to `Human Review`:
|
||||
|
||||
1. Identify the PR number from issue links/attachments.
|
||||
2. Gather feedback from all channels:
|
||||
- Top-level PR comments (`gh pr view --comments`).
|
||||
- Inline review comments (`gh api repos/<owner>/<repo>/pulls/<pr>/comments`).
|
||||
- Review summaries/states (`gh pr view --json reviews`).
|
||||
3. Treat every actionable reviewer comment (human or bot), including inline review comments, as blocking until one of these is true:
|
||||
- code/test/docs updated to address it, or
|
||||
- explicit, justified pushback reply is posted on that thread.
|
||||
4. Update the workpad plan/checklist to include each feedback item and its resolution status.
|
||||
5. Re-run validation after feedback-driven changes and push updates.
|
||||
6. Repeat this sweep until there are no outstanding actionable comments.
|
||||
|
||||
## Blocked-access escape hatch (required behavior)
|
||||
|
||||
Use this only when completion is blocked by missing required tools or missing auth/permissions that cannot be resolved in-session.
|
||||
|
||||
- GitHub is **not** a valid blocker by default. Always try fallback strategies first (alternate remote/auth mode, then continue publish/review flow).
|
||||
- Do not move to `Human Review` for GitHub access/auth until all fallback strategies have been attempted and documented in the workpad.
|
||||
- If a non-GitHub required tool is missing, or required non-GitHub auth is unavailable, move the ticket to `Human Review` with a short blocker brief in the workpad that includes:
|
||||
- what is missing,
|
||||
- why it blocks required acceptance/validation,
|
||||
- exact human action needed to unblock.
|
||||
- Keep the brief concise and action-oriented; do not add extra top-level comments outside the workpad.
|
||||
|
||||
## Step 2: Execution phase (Todo -> In Progress -> Human Review)
|
||||
|
||||
1. Determine current repo state (`branch`, `git status`, `HEAD`) and verify the kickoff `pull` sync result is already recorded in the workpad before implementation continues.
|
||||
2. If current issue state is `Todo`, move it to `In Progress`; otherwise leave the current state unchanged.
|
||||
3. Load the existing workpad comment and treat it as the active execution checklist.
|
||||
- Edit it liberally whenever reality changes (scope, risks, validation approach, discovered tasks).
|
||||
4. Implement against the hierarchical TODOs and keep the comment current:
|
||||
- Check off completed items.
|
||||
- Add newly discovered items in the appropriate section.
|
||||
- Keep parent/child structure intact as scope evolves.
|
||||
- Update the workpad immediately after each meaningful milestone (for example: reproduction complete, code change landed, validation run, review feedback addressed).
|
||||
- Never leave completed work unchecked in the plan.
|
||||
- For tickets that started as `Todo` with an attached PR, run the full PR feedback sweep protocol immediately after kickoff and before new feature work.
|
||||
5. Run validation/tests required for the scope.
|
||||
- Mandatory gate: execute all ticket-provided `Validation`/`Test Plan`/ `Testing` requirements when present; treat unmet items as incomplete work.
|
||||
- Prefer a targeted proof that directly demonstrates the behavior you changed.
|
||||
- You may make temporary local proof edits to validate assumptions (for example: tweak a local build input for `make`, or hardcode a UI account / response path) when this increases confidence.
|
||||
- Revert every temporary proof edit before commit/push.
|
||||
- Document these temporary proof steps and outcomes in the workpad `Validation`/`Notes` sections so reviewers can follow the evidence.
|
||||
- If app-touching, run `launch-app` validation and capture/upload media via `github-pr-media` before handoff.
|
||||
6. Re-check all acceptance criteria and close any gaps.
|
||||
7. Before every `git push` attempt, run the required validation for your scope and confirm it passes; if it fails, address issues and rerun until green, then commit and push changes.
|
||||
8. Attach PR URL to the issue (prefer attachment; use the workpad comment only if attachment is unavailable).
|
||||
- Ensure the GitHub PR has label `symphony` (add it if missing).
|
||||
9. Merge latest `origin/main` into branch, resolve conflicts, and rerun checks.
|
||||
10. Update the workpad comment with final checklist status and validation notes.
|
||||
- Mark completed plan/acceptance/validation checklist items as checked.
|
||||
- Add final handoff notes (commit + validation summary) in the same workpad comment.
|
||||
- Do not include PR URL in the workpad comment; keep PR linkage on the issue via attachment/link fields.
|
||||
- Add a short `### Confusions` section at the bottom when any part of task execution was unclear/confusing, with concise bullets.
|
||||
- Do not post any additional completion summary comment.
|
||||
11. Before moving to `Human Review`, poll PR feedback and checks:
|
||||
- Read the PR `Manual QA Plan` comment (when present) and use it to sharpen UI/runtime test coverage for the current change.
|
||||
- Run the full PR feedback sweep protocol.
|
||||
- Confirm PR checks are passing (green) after the latest changes.
|
||||
- Confirm every required ticket-provided validation/test-plan item is explicitly marked complete in the workpad.
|
||||
- Repeat this check-address-verify loop until no outstanding comments remain and checks are fully passing.
|
||||
- Re-open and refresh the workpad before state transition so `Plan`, `Acceptance Criteria`, and `Validation` exactly match completed work.
|
||||
12. Only then move issue to `Human Review`.
|
||||
- Exception: if blocked by missing required non-GitHub tools/auth per the blocked-access escape hatch, move to `Human Review` with the blocker brief and explicit unblock actions.
|
||||
13. For `Todo` tickets that already had a PR attached at kickoff:
|
||||
- Ensure all existing PR feedback was reviewed and resolved, including inline review comments (code changes or explicit, justified pushback response).
|
||||
- Ensure branch was pushed with any required updates.
|
||||
- Then move to `Human Review`.
|
||||
|
||||
## Step 3: Human Review and merge handling
|
||||
|
||||
1. When the issue is in `Human Review`, do not code or change ticket content.
|
||||
2. Poll for updates as needed, including GitHub PR review comments from humans and bots.
|
||||
3. If review feedback requires changes, move the issue to `Rework` and follow the rework flow.
|
||||
4. If approved, human moves the issue to `Merging`.
|
||||
5. When the issue is in `Merging`, open and follow `.codex/skills/land/SKILL.md`, then run the `land` skill in a loop until the PR is merged. Do not call `gh pr merge` directly.
|
||||
6. After merge is complete, move the issue to `Done`.
|
||||
|
||||
## Step 4: Rework handling
|
||||
|
||||
1. Treat `Rework` as a full approach reset, not incremental patching.
|
||||
2. Re-read the full issue body and all human comments; explicitly identify what will be done differently this attempt.
|
||||
3. Close the existing PR tied to the issue.
|
||||
4. Remove the existing `## Codex Workpad` comment from the issue.
|
||||
5. Create a fresh branch from `origin/main`.
|
||||
6. Start over from the normal kickoff flow:
|
||||
- If current issue state is `Todo`, move it to `In Progress`; otherwise keep the current state.
|
||||
- Create a new bootstrap `## Codex Workpad` comment.
|
||||
- Build a fresh plan/checklist and execute end-to-end.
|
||||
|
||||
## Completion bar before Human Review
|
||||
|
||||
- Step 1/2 checklist is fully complete and accurately reflected in the single workpad comment.
|
||||
- Acceptance criteria and required ticket-provided validation items are complete.
|
||||
- Validation/tests are green for the latest commit.
|
||||
- PR feedback sweep is complete and no actionable comments remain.
|
||||
- PR checks are green, branch is pushed, and PR is linked on the issue.
|
||||
- Required PR metadata is present (`symphony` label).
|
||||
- If app-touching, runtime validation/media requirements from `App runtime validation (required)` are complete.
|
||||
|
||||
## Guardrails
|
||||
|
||||
- If the branch PR is already closed/merged, do not reuse that branch or prior implementation state for continuation.
|
||||
- For closed/merged branch PRs, create a new branch from `origin/main` and restart from reproduction/planning as if starting fresh.
|
||||
- If issue state is `Backlog`, do not modify it; wait for human to move to `Todo`.
|
||||
- Do not edit the issue body/description for planning or progress tracking.
|
||||
- Use exactly one persistent workpad comment (`## Codex Workpad`) per issue.
|
||||
- If comment editing is unavailable in-session, use the update script. Only report blocked if both MCP editing and script-based editing are unavailable.
|
||||
- Temporary proof edits are allowed only for local verification and must be reverted before commit.
|
||||
- If out-of-scope improvements are found, create a separate Backlog issue rather
|
||||
than expanding current scope, and include a clear
|
||||
title/description/acceptance criteria, same-project assignment, a `related`
|
||||
link to the current issue, and `blockedBy` when the follow-up depends on the
|
||||
current issue.
|
||||
- Do not move to `Human Review` unless the `Completion bar before Human Review` is satisfied.
|
||||
- In `Human Review`, do not make changes; wait and poll.
|
||||
- If state is terminal (`Done`), do nothing and shut down.
|
||||
- Keep issue text concise, specific, and reviewer-oriented.
|
||||
- If blocked and no workpad exists yet, add one blocker comment describing blocker, impact, and next unblock action.
|
||||
|
||||
## Workpad template
|
||||
|
||||
Use this exact structure for the persistent workpad comment and keep it updated in place throughout execution:
|
||||
|
||||
````md
|
||||
## Codex Workpad
|
||||
|
||||
```text
|
||||
<hostname>:<abs-path>@<short-sha>
|
||||
```
|
||||
|
||||
### Plan
|
||||
|
||||
- [ ] 1\. Parent task
|
||||
- [ ] 1.1 Child task
|
||||
- [ ] 1.2 Child task
|
||||
- [ ] 2\. Parent task
|
||||
|
||||
### Acceptance Criteria
|
||||
|
||||
- [ ] Criterion 1
|
||||
- [ ] Criterion 2
|
||||
|
||||
### Validation
|
||||
|
||||
- [ ] targeted tests: `<command>`
|
||||
|
||||
### Notes
|
||||
|
||||
- <short progress note with timestamp>
|
||||
|
||||
### Confusions
|
||||
|
||||
- <only include when something was confusing during execution>
|
||||
````
|
||||
|
||||
3
autogpt_platform/.gitignore
vendored
3
autogpt_platform/.gitignore
vendored
@@ -1,6 +1,3 @@
|
||||
*.ignore.*
|
||||
*.ign.*
|
||||
.application.logs
|
||||
|
||||
# Claude Code local settings only — the rest of .claude/ is shared (skills etc.)
|
||||
.claude/settings.local.json
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
# AutoGPT Platform
|
||||
|
||||
This file provides guidance to coding agents when working with code in this repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`frontend`): Next.js React application
|
||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||
|
||||
## Component Documentation
|
||||
|
||||
- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
|
||||
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns
|
||||
|
||||
## Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Branching Strategy
|
||||
|
||||
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||
- **`master`** is the production branch. Only used for production releases.
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
|
||||
```bash
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
## Summary
|
||||
- use `backticks` freely here
|
||||
PREOF
|
||||
gh pr create --title "..." --body-file "$PR_BODY" --base dev
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, follow a test-first approach:
|
||||
|
||||
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
|
||||
2. **Implement the fix/feature** — write the minimal code to make the test pass.
|
||||
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
|
||||
|
||||
This ensures every change is covered by a test and that the test actually validates the intended behavior.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
@@ -1 +1,118 @@
|
||||
@AGENTS.md
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`frontend`): Next.js React application
|
||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||
|
||||
## Component Documentation
|
||||
|
||||
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
|
||||
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
|
||||
|
||||
## Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Branching Strategy
|
||||
|
||||
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||
- **`master`** is the production branch. Only used for production releases.
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
|
||||
```bash
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
## Summary
|
||||
- use `backticks` freely here
|
||||
PREOF
|
||||
gh pr create --title "..." --body-file "$PR_BODY" --base dev
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, follow a test-first approach:
|
||||
|
||||
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
|
||||
2. **Implement the fix/feature** — write the minimal code to make the test pass.
|
||||
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
|
||||
|
||||
This ensures every change is covered by a test and that the test actually validates the intended behavior.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.platform_cost_log
|
||||
-- Looker source alias: ds115 | Charts: 0
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per platform cost log entry (last 90 days).
|
||||
-- Tracks real API spend at the call level: provider, model,
|
||||
-- token counts (including Anthropic cache tokens), cost in
|
||||
-- microdollars, and the block/execution that incurred the cost.
|
||||
-- Joins the User table to provide email for per-user breakdowns.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.PlatformCostLog — Per-call cost records
|
||||
-- platform.User — User email
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Log entry UUID
|
||||
-- createdAt TIMESTAMPTZ When the cost was recorded
|
||||
-- userId TEXT User who incurred the cost (nullable)
|
||||
-- email TEXT User email (nullable)
|
||||
-- graphExecId TEXT Graph execution UUID (nullable)
|
||||
-- nodeExecId TEXT Node execution UUID (nullable)
|
||||
-- blockName TEXT Block that made the API call (nullable)
|
||||
-- provider TEXT API provider, lowercase (e.g. 'openai', 'anthropic')
|
||||
-- model TEXT Model name (nullable)
|
||||
-- trackingType TEXT Cost unit: 'tokens' | 'cost_usd' | 'characters' | etc.
|
||||
-- costMicrodollars BIGINT Cost in microdollars (divide by 1,000,000 for USD)
|
||||
-- costUsd FLOAT Cost in USD (costMicrodollars / 1,000,000)
|
||||
-- inputTokens INT Prompt/input tokens (nullable)
|
||||
-- outputTokens INT Completion/output tokens (nullable)
|
||||
-- cacheReadTokens INT Anthropic cache-read tokens billed at 10% (nullable)
|
||||
-- cacheCreationTokens INT Anthropic cache-write tokens billed at 125% (nullable)
|
||||
-- totalTokens INT inputTokens + outputTokens (nullable if either is null)
|
||||
-- duration FLOAT API call duration in seconds (nullable)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Total spend by provider (last 90 days)
|
||||
-- SELECT provider, SUM("costUsd") AS total_usd, COUNT(*) AS calls
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- GROUP BY 1 ORDER BY total_usd DESC;
|
||||
--
|
||||
-- -- Spend by model
|
||||
-- SELECT provider, model, SUM("costUsd") AS total_usd,
|
||||
-- SUM("inputTokens") AS input_tokens,
|
||||
-- SUM("outputTokens") AS output_tokens
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- WHERE model IS NOT NULL
|
||||
-- GROUP BY 1, 2 ORDER BY total_usd DESC;
|
||||
--
|
||||
-- -- Top 20 users by spend
|
||||
-- SELECT "userId", email, SUM("costUsd") AS total_usd, COUNT(*) AS calls
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- WHERE "userId" IS NOT NULL
|
||||
-- GROUP BY 1, 2 ORDER BY total_usd DESC LIMIT 20;
|
||||
--
|
||||
-- -- Daily spend trend
|
||||
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
|
||||
-- SUM("costUsd") AS daily_usd,
|
||||
-- COUNT(*) AS calls
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
--
|
||||
-- -- Cache hit rate for Anthropic (cache reads vs total reads)
|
||||
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
|
||||
-- SUM("cacheReadTokens")::float /
|
||||
-- NULLIF(SUM("inputTokens" + COALESCE("cacheReadTokens", 0)), 0) AS cache_hit_rate
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- WHERE provider = 'anthropic'
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
p."id" AS id,
|
||||
p."createdAt" AS createdAt,
|
||||
p."userId" AS userId,
|
||||
u."email" AS email,
|
||||
p."graphExecId" AS graphExecId,
|
||||
p."nodeExecId" AS nodeExecId,
|
||||
p."blockName" AS blockName,
|
||||
p."provider" AS provider,
|
||||
p."model" AS model,
|
||||
p."trackingType" AS trackingType,
|
||||
p."costMicrodollars" AS costMicrodollars,
|
||||
p."costMicrodollars"::float / 1000000.0 AS costUsd,
|
||||
p."inputTokens" AS inputTokens,
|
||||
p."outputTokens" AS outputTokens,
|
||||
p."cacheReadTokens" AS cacheReadTokens,
|
||||
p."cacheCreationTokens" AS cacheCreationTokens,
|
||||
CASE
|
||||
WHEN p."inputTokens" IS NOT NULL AND p."outputTokens" IS NOT NULL
|
||||
THEN p."inputTokens" + p."outputTokens"
|
||||
ELSE NULL
|
||||
END AS totalTokens,
|
||||
p."duration" AS duration
|
||||
FROM platform."PlatformCostLog" p
|
||||
LEFT JOIN platform."User" u ON u."id" = p."userId"
|
||||
WHERE p."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
@@ -59,8 +59,6 @@ class OAuthState(BaseModel):
|
||||
code_verifier: Optional[str] = None
|
||||
scopes: list[str]
|
||||
"""Unix timestamp (seconds) indicating when this OAuth state expires"""
|
||||
credential_id: Optional[str] = None
|
||||
"""If set, this OAuth flow upgrades an existing credential's scopes."""
|
||||
|
||||
|
||||
class UserMetadata(BaseModel):
|
||||
|
||||
54
autogpt_platform/autogpt_libs/poetry.lock
generated
54
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
@@ -67,7 +67,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
|
||||
optional = false
|
||||
python-versions = "<3.11,>=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version == \"3.10\""
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
||||
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
||||
@@ -541,7 +541,7 @@ description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version == \"3.10\""
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
|
||||
{file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
|
||||
@@ -2181,14 +2181,14 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "7.1.0"
|
||||
version = "7.0.0"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678"},
|
||||
{file = "pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2"},
|
||||
{file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"},
|
||||
{file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2342,30 +2342,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.15.7"
|
||||
version = "0.15.0"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580"},
|
||||
{file = "ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de"},
|
||||
{file = "ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1"},
|
||||
{file = "ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2"},
|
||||
{file = "ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac"},
|
||||
{file = "ruff-0.15.0-py3-none-linux_armv6l.whl", hash = "sha256:aac4ebaa612a82b23d45964586f24ae9bc23ca101919f5590bdb368d74ad5455"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:dcd4be7cc75cfbbca24a98d04d0b9b36a270d0833241f776b788d59f4142b14d"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d747e3319b2bce179c7c1eaad3d884dc0a199b5f4d5187620530adf9105268ce"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:650bd9c56ae03102c51a5e4b554d74d825ff3abe4db22b90fd32d816c2e90621"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a6664b7eac559e3048223a2da77769c2f92b43a6dfd4720cef42654299a599c9"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f811f97b0f092b35320d1556f3353bf238763420ade5d9e62ebd2b73f2ff179"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:761ec0a66680fab6454236635a39abaf14198818c8cdf691e036f4bc0f406b2d"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:940f11c2604d317e797b289f4f9f3fa5555ffe4fb574b55ed006c3d9b6f0eb78"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbca3d40558789126da91d7ef9a7c87772ee107033db7191edefa34e2c7f1b4"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9a121a96db1d75fa3eb39c4539e607f628920dd72ff1f7c5ee4f1b768ac62d6e"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5298d518e493061f2eabd4abd067c7e4fb89e2f63291c94332e35631c07c3662"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:afb6e603d6375ff0d6b0cee563fa21ab570fd15e65c852cb24922cef25050cf1"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:77e515f6b15f828b94dc17d2b4ace334c9ddb7d9468c54b2f9ed2b9c1593ef16"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6f6e80850a01eb13b3e42ee0ebdf6e4497151b48c35051aab51c101266d187a3"},
|
||||
{file = "ruff-0.15.0-py3-none-win32.whl", hash = "sha256:238a717ef803e501b6d51e0bdd0d2c6e8513fe9eec14002445134d3907cd46c3"},
|
||||
{file = "ruff-0.15.0-py3-none-win_amd64.whl", hash = "sha256:dd5e4d3301dc01de614da3cdffc33d4b1b96fb89e45721f1598e5532ccf78b18"},
|
||||
{file = "ruff-0.15.0-py3-none-win_arm64.whl", hash = "sha256:c480d632cc0ca3f0727acac8b7d053542d9e114a462a145d0b00e7cd658c515a"},
|
||||
{file = "ruff-0.15.0.tar.gz", hash = "sha256:6bdea47cdbea30d40f8f8d7d69c0854ba7c15420ec75a26f463290949d7f7e9a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2564,7 +2564,7 @@ description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version == \"3.10\""
|
||||
markers = "python_version < \"3.11\""
|
||||
files = [
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
|
||||
@@ -2912,4 +2912,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "e0936a065565550afed18f6298b7e04e814b44100def7049f1a0d68662624a39"
|
||||
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"
|
||||
|
||||
@@ -26,8 +26,8 @@ pyright = "^1.1.408"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-cov = "^7.1.0"
|
||||
ruff = "^0.15.7"
|
||||
pytest-cov = "^7.0.0"
|
||||
ruff = "^0.15.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -58,17 +58,6 @@ V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Graphiti Temporal Knowledge Graph Memory
|
||||
# Rollout controlled by LaunchDarkly flag "graphiti-memory"
|
||||
# LLM key falls back to CHAT_API_KEY (AutoPilot), then OPEN_ROUTER_API_KEY.
|
||||
# Embedder key falls back to CHAT_OPENAI_API_KEY (AutoPilot), then OPENAI_API_KEY.
|
||||
GRAPHITI_FALKORDB_HOST=localhost
|
||||
GRAPHITI_FALKORDB_PORT=6380
|
||||
GRAPHITI_FALKORDB_PASSWORD=
|
||||
GRAPHITI_LLM_MODEL=gpt-4.1-mini
|
||||
GRAPHITI_EMBEDDER_MODEL=text-embedding-3-small
|
||||
GRAPHITI_SEMAPHORE_LIMIT=5
|
||||
|
||||
# Langfuse Prompt Management
|
||||
# Used for managing the CoPilot system prompt externally
|
||||
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
|
||||
@@ -179,9 +168,6 @@ MEM0_API_KEY=
|
||||
OPENWEATHERMAP_API_KEY=
|
||||
GOOGLE_MAPS_API_KEY=
|
||||
|
||||
# Platform Bot Linking
|
||||
PLATFORM_LINK_BASE_URL=http://localhost:3000/link
|
||||
|
||||
# Communication Services
|
||||
DISCORD_BOT_TOKEN=
|
||||
MEDIUM_API_KEY=
|
||||
@@ -192,7 +178,6 @@ SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Business & Marketing Tools
|
||||
AGENTMAIL_API_KEY=
|
||||
APOLLO_API_KEY=
|
||||
ENRICHLAYER_API_KEY=
|
||||
AYRSHARE_API_KEY=
|
||||
|
||||
@@ -1,227 +0,0 @@
|
||||
# Backend
|
||||
|
||||
This file provides guidance to coding agents when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
#### Handling files in blocks with `store_media_file()`
|
||||
|
||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||
|
||||
| Format | Use When | Returns |
|
||||
|--------|----------|---------|
|
||||
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# INPUT: Need to process file locally with ffmpeg
|
||||
local_path = await store_media_file(
|
||||
file=input_data.video,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||
|
||||
# INPUT: Need to send to external API like Replicate
|
||||
image_b64 = await store_media_file(
|
||||
file=input_data.image,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
|
||||
|
||||
# OUTPUT: Returning result from block
|
||||
result_url = await store_media_file(
|
||||
file=generated_image_url,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", result_url
|
||||
# In CoPilot: result_url = "workspace://abc123"
|
||||
# In graphs: result_url = "data:image/png;base64,..."
|
||||
```
|
||||
|
||||
**Key points:**
|
||||
|
||||
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||
|
||||
### Modifying the API
|
||||
|
||||
1. Update route in `backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Workspace & Media Files
|
||||
|
||||
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
|
||||
- Working on CoPilot file upload/download features
|
||||
- Building blocks that handle `MediaFileType` inputs/outputs
|
||||
- Modifying `WorkspaceManager` or `store_media_file()`
|
||||
- Debugging file persistence or virus scanning issues
|
||||
|
||||
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
- Located in `backend/api/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
@@ -1 +1,226 @@
|
||||
@AGENTS.md
|
||||
# CLAUDE.md - Backend
|
||||
|
||||
This file provides guidance to Claude Code when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
#### Handling files in blocks with `store_media_file()`
|
||||
|
||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||
|
||||
| Format | Use When | Returns |
|
||||
|--------|----------|---------|
|
||||
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# INPUT: Need to process file locally with ffmpeg
|
||||
local_path = await store_media_file(
|
||||
file=input_data.video,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||
|
||||
# INPUT: Need to send to external API like Replicate
|
||||
image_b64 = await store_media_file(
|
||||
file=input_data.image,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
|
||||
|
||||
# OUTPUT: Returning result from block
|
||||
result_url = await store_media_file(
|
||||
file=generated_image_url,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", result_url
|
||||
# In CoPilot: result_url = "workspace://abc123"
|
||||
# In graphs: result_url = "data:image/png;base64,..."
|
||||
```
|
||||
|
||||
**Key points:**
|
||||
|
||||
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||
|
||||
### Modifying the API
|
||||
|
||||
1. Update route in `backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Workspace & Media Files
|
||||
|
||||
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
|
||||
- Working on CoPilot file upload/download features
|
||||
- Building blocks that handle `MediaFileType` inputs/outputs
|
||||
- Modifying `WorkspaceManager` or `store_media_file()`
|
||||
- Debugging file persistence or virus scanning issues
|
||||
|
||||
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
- Located in `backend/api/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
|
||||
@@ -121,20 +121,36 @@ RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) using the system chromium package.
|
||||
# Chrome for Testing (the binary agent-browser downloads via `agent-browser install`)
|
||||
# has no ARM64 builds, so we use the distro-packaged chromium instead — verified to
|
||||
# work with agent-browser via Docker tests on arm64; amd64 is validated in CI.
|
||||
# Note: system chromium tracks the Debian package schedule rather than a pinned
|
||||
# Chrome for Testing release. If agent-browser requires a specific Chrome version,
|
||||
# verify compatibility against the chromium package version in the base image.
|
||||
# Install agent-browser (Copilot browser tool) + Chromium.
|
||||
# On amd64: install runtime libs + run `agent-browser install` to download
|
||||
# Chrome for Testing (pinned version, tested with Playwright).
|
||||
# On arm64: install system chromium package — Chrome for Testing has no ARM64
|
||||
# binary. AGENT_BROWSER_EXECUTABLE_PATH is set at runtime by the entrypoint
|
||||
# script (below) to redirect agent-browser to the system binary.
|
||||
ARG TARGETARCH
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends chromium fonts-liberation \
|
||||
&& if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
apt-get install -y --no-install-recommends chromium fonts-liberation; \
|
||||
else \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1; \
|
||||
fi \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& ([ "$TARGETARCH" = "arm64" ] || agent-browser install) \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
ENV AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium
|
||||
# On arm64 the system chromium is at /usr/bin/chromium; set
|
||||
# AGENT_BROWSER_EXECUTABLE_PATH so agent-browser's daemon uses it instead of
|
||||
# Chrome for Testing (which has no ARM64 binary). On amd64 the variable is left
|
||||
# unset so agent-browser uses the Chrome for Testing binary it downloaded above.
|
||||
RUN printf '#!/bin/sh\n[ -x /usr/bin/chromium ] && export AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium\nexec "$@"\n' \
|
||||
> /usr/local/bin/entrypoint.sh \
|
||||
&& chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
@@ -157,4 +173,5 @@ RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
|
||||
ENV PORT=8000
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["rest"]
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
{
|
||||
"id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"version": 2,
|
||||
"is_active": true,
|
||||
"name": "Calculator agent",
|
||||
"description": "",
|
||||
"instructions": null,
|
||||
"recommended_schedule_cron": null,
|
||||
"forked_from_id": null,
|
||||
"forked_from_version": null,
|
||||
"user_id": "",
|
||||
"created_at": "2026-04-13T03:45:11.241Z",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"block_id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
"input_default": {
|
||||
"name": "Input",
|
||||
"secret": false,
|
||||
"advanced": false
|
||||
},
|
||||
"metadata": {
|
||||
"position": {
|
||||
"x": -188.2244873046875,
|
||||
"y": 95
|
||||
}
|
||||
},
|
||||
"input_links": [],
|
||||
"output_links": [
|
||||
{
|
||||
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
|
||||
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"source_name": "result",
|
||||
"sink_name": "a",
|
||||
"is_static": true
|
||||
}
|
||||
],
|
||||
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"graph_version": 2,
|
||||
"webhook_id": null
|
||||
},
|
||||
{
|
||||
"id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"block_id": "363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
"input_default": {
|
||||
"name": "Output",
|
||||
"secret": false,
|
||||
"advanced": false,
|
||||
"escape_html": false
|
||||
},
|
||||
"metadata": {
|
||||
"position": {
|
||||
"x": 825.198974609375,
|
||||
"y": 123.75
|
||||
}
|
||||
},
|
||||
"input_links": [
|
||||
{
|
||||
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
|
||||
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": false
|
||||
}
|
||||
],
|
||||
"output_links": [],
|
||||
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"graph_version": 2,
|
||||
"webhook_id": null
|
||||
},
|
||||
{
|
||||
"id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"block_id": "b1ab9b19-67a6-406d-abf5-2dba76d00c79",
|
||||
"input_default": {
|
||||
"b": 34,
|
||||
"operation": "Add",
|
||||
"round_result": false
|
||||
},
|
||||
"metadata": {
|
||||
"position": {
|
||||
"x": 323.0255126953125,
|
||||
"y": 121.25
|
||||
}
|
||||
},
|
||||
"input_links": [
|
||||
{
|
||||
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
|
||||
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"source_name": "result",
|
||||
"sink_name": "a",
|
||||
"is_static": true
|
||||
}
|
||||
],
|
||||
"output_links": [
|
||||
{
|
||||
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
|
||||
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": false
|
||||
}
|
||||
],
|
||||
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"graph_version": 2,
|
||||
"webhook_id": null
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
|
||||
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": false
|
||||
},
|
||||
{
|
||||
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
|
||||
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"source_name": "result",
|
||||
"sink_name": "a",
|
||||
"is_static": true
|
||||
}
|
||||
],
|
||||
"sub_graphs": [],
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"Input": {
|
||||
"advanced": false,
|
||||
"secret": false,
|
||||
"title": "Input"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"Input"
|
||||
]
|
||||
},
|
||||
"output_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"Output": {
|
||||
"advanced": false,
|
||||
"secret": false,
|
||||
"title": "Output"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"Output"
|
||||
]
|
||||
},
|
||||
"has_external_trigger": false,
|
||||
"has_human_in_the_loop": false,
|
||||
"has_sensitive_action": false,
|
||||
"trigger_setup_info": null,
|
||||
"credentials_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
@@ -18,22 +18,14 @@ from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.integrations.models import get_all_provider_names
|
||||
from backend.api.features.integrations.router import (
|
||||
CredentialsMetaResponse,
|
||||
to_meta_response,
|
||||
)
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
@@ -99,6 +91,18 @@ class OAuthCompleteResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class CredentialSummary(BaseModel):
|
||||
"""Summary of a credential without sensitive data."""
|
||||
|
||||
id: str
|
||||
provider: str
|
||||
type: CredentialsType
|
||||
title: Optional[str] = None
|
||||
scopes: Optional[list[str]] = None
|
||||
username: Optional[str] = None
|
||||
host: Optional[str] = None
|
||||
|
||||
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Information about an integration provider."""
|
||||
|
||||
@@ -469,12 +473,12 @@ async def complete_oauth(
|
||||
)
|
||||
|
||||
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialsMetaResponse])
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
||||
async def list_credentials(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
) -> list[CredentialSummary]:
|
||||
"""
|
||||
List all credentials for the authenticated user.
|
||||
|
||||
@@ -482,19 +486,28 @@ async def list_credentials(
|
||||
"""
|
||||
credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
return [
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
|
||||
|
||||
@integrations_router.get(
|
||||
"/{provider}/credentials", response_model=list[CredentialsMetaResponse]
|
||||
"/{provider}/credentials", response_model=list[CredentialSummary]
|
||||
)
|
||||
async def list_credentials_by_provider(
|
||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
) -> list[CredentialSummary]:
|
||||
"""
|
||||
List credentials for a specific provider.
|
||||
"""
|
||||
@@ -502,7 +515,16 @@ async def list_credentials_by_provider(
|
||||
auth.user_id, provider
|
||||
)
|
||||
return [
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
|
||||
|
||||
@@ -575,11 +597,11 @@ async def create_credential(
|
||||
# Store credentials
|
||||
try:
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
except Exception:
|
||||
logger.exception("Failed to store credentials")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store credentials: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to store credentials",
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
)
|
||||
|
||||
logger.info(f"Created {request.type} credentials for provider {provider}")
|
||||
@@ -617,23 +639,15 @@ async def delete_credential(
|
||||
use the main API's delete endpoint which handles webhook cleanup and
|
||||
token revocation.
|
||||
"""
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if not provider_matches(creds.provider, provider):
|
||||
if creds.provider != provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
|
||||
await creds_manager.delete(auth.user_id, cred_id)
|
||||
|
||||
@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):
|
||||
|
||||
def _create_ephemeral_session(user_id: str) -> ChatSession:
|
||||
"""Create an ephemeral session for stateless API requests."""
|
||||
return ChatSession.new(user_id, dry_run=False)
|
||||
return ChatSession.new(user_id)
|
||||
|
||||
|
||||
@tools_router.post(
|
||||
|
||||
@@ -1,932 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from autogpt_libs.auth import requires_admin_user
|
||||
from autogpt_libs.auth.models import User as AuthUser
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.admin.model import (
|
||||
AgentDiagnosticsResponse,
|
||||
ExecutionDiagnosticsResponse,
|
||||
)
|
||||
from backend.data.diagnostics import (
|
||||
FailedExecutionDetail,
|
||||
OrphanedScheduleDetail,
|
||||
RunningExecutionDetail,
|
||||
ScheduleDetail,
|
||||
ScheduleHealthMetrics,
|
||||
cleanup_all_stuck_queued_executions,
|
||||
cleanup_orphaned_executions_bulk,
|
||||
cleanup_orphaned_schedules_bulk,
|
||||
get_agent_diagnostics,
|
||||
get_all_orphaned_execution_ids,
|
||||
get_all_schedules_details,
|
||||
get_all_stuck_queued_execution_ids,
|
||||
get_execution_diagnostics,
|
||||
get_failed_executions_count,
|
||||
get_failed_executions_details,
|
||||
get_invalid_executions_details,
|
||||
get_long_running_executions_details,
|
||||
get_orphaned_executions_details,
|
||||
get_orphaned_schedules_details,
|
||||
get_running_executions_details,
|
||||
get_schedule_health_metrics,
|
||||
get_stuck_queued_executions_details,
|
||||
stop_all_long_running_executions,
|
||||
)
|
||||
from backend.data.execution import get_graph_executions
|
||||
from backend.executor.utils import add_graph_execution, stop_graph_execution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["diagnostics", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
class RunningExecutionsListResponse(BaseModel):
|
||||
"""Response model for list of running executions"""
|
||||
|
||||
executions: List[RunningExecutionDetail]
|
||||
total: int
|
||||
|
||||
|
||||
class FailedExecutionsListResponse(BaseModel):
|
||||
"""Response model for list of failed executions"""
|
||||
|
||||
executions: List[FailedExecutionDetail]
|
||||
total: int
|
||||
|
||||
|
||||
class StopExecutionRequest(BaseModel):
|
||||
"""Request model for stopping a single execution"""
|
||||
|
||||
execution_id: str
|
||||
|
||||
|
||||
class StopExecutionsRequest(BaseModel):
|
||||
"""Request model for stopping multiple executions"""
|
||||
|
||||
execution_ids: List[str]
|
||||
|
||||
|
||||
class StopExecutionResponse(BaseModel):
|
||||
"""Response model for stop execution operations"""
|
||||
|
||||
success: bool
|
||||
stopped_count: int = 0
|
||||
message: str
|
||||
|
||||
|
||||
class RequeueExecutionResponse(BaseModel):
|
||||
"""Response model for requeue execution operations"""
|
||||
|
||||
success: bool
|
||||
requeued_count: int = 0
|
||||
message: str
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions",
|
||||
response_model=ExecutionDiagnosticsResponse,
|
||||
summary="Get Execution Diagnostics",
|
||||
)
|
||||
async def get_execution_diagnostics_endpoint():
|
||||
"""
|
||||
Get comprehensive diagnostic information about execution status.
|
||||
|
||||
Returns all execution metrics including:
|
||||
- Current state (running, queued)
|
||||
- Orphaned executions (>24h old, likely not in executor)
|
||||
- Failure metrics (1h, 24h, rate)
|
||||
- Long-running detection (stuck >1h, >24h)
|
||||
- Stuck queued detection
|
||||
- Throughput metrics (completions/hour)
|
||||
- RabbitMQ queue depths
|
||||
"""
|
||||
logger.info("Getting execution diagnostics")
|
||||
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
|
||||
response = ExecutionDiagnosticsResponse(
|
||||
running_executions=diagnostics.running_count,
|
||||
queued_executions_db=diagnostics.queued_db_count,
|
||||
queued_executions_rabbitmq=diagnostics.rabbitmq_queue_depth,
|
||||
cancel_queue_depth=diagnostics.cancel_queue_depth,
|
||||
orphaned_running=diagnostics.orphaned_running,
|
||||
orphaned_queued=diagnostics.orphaned_queued,
|
||||
failed_count_1h=diagnostics.failed_count_1h,
|
||||
failed_count_24h=diagnostics.failed_count_24h,
|
||||
failure_rate_24h=diagnostics.failure_rate_24h,
|
||||
stuck_running_24h=diagnostics.stuck_running_24h,
|
||||
stuck_running_1h=diagnostics.stuck_running_1h,
|
||||
oldest_running_hours=diagnostics.oldest_running_hours,
|
||||
stuck_queued_1h=diagnostics.stuck_queued_1h,
|
||||
queued_never_started=diagnostics.queued_never_started,
|
||||
invalid_queued_with_start=diagnostics.invalid_queued_with_start,
|
||||
invalid_running_without_start=diagnostics.invalid_running_without_start,
|
||||
completed_1h=diagnostics.completed_1h,
|
||||
completed_24h=diagnostics.completed_24h,
|
||||
throughput_per_hour=diagnostics.throughput_per_hour,
|
||||
timestamp=diagnostics.timestamp,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Execution diagnostics: running={diagnostics.running_count}, "
|
||||
f"queued_db={diagnostics.queued_db_count}, "
|
||||
f"orphaned={diagnostics.orphaned_running + diagnostics.orphaned_queued}, "
|
||||
f"failed_24h={diagnostics.failed_count_24h}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/agents",
|
||||
response_model=AgentDiagnosticsResponse,
|
||||
summary="Get Agent Diagnostics",
|
||||
)
|
||||
async def get_agent_diagnostics_endpoint():
|
||||
"""
|
||||
Get diagnostic information about agents.
|
||||
|
||||
Returns:
|
||||
- agents_with_active_executions: Number of unique agents with running/queued executions
|
||||
- timestamp: Current timestamp
|
||||
"""
|
||||
logger.info("Getting agent diagnostics")
|
||||
|
||||
diagnostics = await get_agent_diagnostics()
|
||||
|
||||
response = AgentDiagnosticsResponse(
|
||||
agents_with_active_executions=diagnostics.agents_with_active_executions,
|
||||
timestamp=diagnostics.timestamp,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Agent diagnostics: with_active_executions={diagnostics.agents_with_active_executions}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/running",
|
||||
response_model=RunningExecutionsListResponse,
|
||||
summary="List Running Executions",
|
||||
)
|
||||
async def list_running_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of running and queued executions (recent, likely active).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of running executions with details
|
||||
"""
|
||||
logger.info(f"Listing running executions (limit={limit}, offset={offset})")
|
||||
|
||||
executions = await get_running_executions_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count for pagination
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
total = diagnostics.running_count + diagnostics.queued_db_count
|
||||
|
||||
return RunningExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/orphaned",
|
||||
response_model=RunningExecutionsListResponse,
|
||||
summary="List Orphaned Executions",
|
||||
)
|
||||
async def list_orphaned_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of orphaned executions (>24h old, likely not in executor).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of orphaned executions with details
|
||||
"""
|
||||
logger.info(f"Listing orphaned executions (limit={limit}, offset={offset})")
|
||||
|
||||
executions = await get_orphaned_executions_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count for pagination
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
total = diagnostics.orphaned_running + diagnostics.orphaned_queued
|
||||
|
||||
return RunningExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/failed",
|
||||
response_model=FailedExecutionsListResponse,
|
||||
summary="List Failed Executions",
|
||||
)
|
||||
async def list_failed_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
hours: int = 24,
|
||||
):
|
||||
"""
|
||||
Get detailed list of failed executions.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
hours: Number of hours to look back (default 24)
|
||||
|
||||
Returns:
|
||||
List of failed executions with error details
|
||||
"""
|
||||
logger.info(
|
||||
f"Listing failed executions (limit={limit}, offset={offset}, hours={hours})"
|
||||
)
|
||||
|
||||
executions = await get_failed_executions_details(
|
||||
limit=limit, offset=offset, hours=hours
|
||||
)
|
||||
|
||||
# Get total count for pagination
|
||||
# Always count actual total for given hours parameter
|
||||
total = await get_failed_executions_count(hours=hours)
|
||||
|
||||
return FailedExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/long-running",
|
||||
response_model=RunningExecutionsListResponse,
|
||||
summary="List Long-Running Executions",
|
||||
)
|
||||
async def list_long_running_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of long-running executions (RUNNING status >24h).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of long-running executions with details
|
||||
"""
|
||||
logger.info(f"Listing long-running executions (limit={limit}, offset={offset})")
|
||||
|
||||
executions = await get_long_running_executions_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count for pagination
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
total = diagnostics.stuck_running_24h
|
||||
|
||||
return RunningExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/stuck-queued",
|
||||
response_model=RunningExecutionsListResponse,
|
||||
summary="List Stuck Queued Executions",
|
||||
)
|
||||
async def list_stuck_queued_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of stuck queued executions (QUEUED >1h, never started).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of stuck queued executions with details
|
||||
"""
|
||||
logger.info(f"Listing stuck queued executions (limit={limit}, offset={offset})")
|
||||
|
||||
executions = await get_stuck_queued_executions_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count for pagination
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
total = diagnostics.stuck_queued_1h
|
||||
|
||||
return RunningExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/invalid",
|
||||
response_model=RunningExecutionsListResponse,
|
||||
summary="List Invalid Executions",
|
||||
)
|
||||
async def list_invalid_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of executions in invalid states (READ-ONLY).
|
||||
|
||||
Invalid states indicate data corruption and require manual investigation:
|
||||
- QUEUED but has startedAt (impossible - can't start while queued)
|
||||
- RUNNING but no startedAt (impossible - can't run without starting)
|
||||
|
||||
⚠️ NO BULK ACTIONS PROVIDED - These need case-by-case investigation.
|
||||
|
||||
Each invalid execution likely has a different root cause (crashes, race conditions,
|
||||
DB corruption). Investigate the execution history and logs to determine appropriate
|
||||
action (manual cleanup, status fix, or leave as-is if system recovered).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of invalid state executions with details
|
||||
"""
|
||||
logger.info(f"Listing invalid state executions (limit={limit}, offset={offset})")
|
||||
|
||||
executions = await get_invalid_executions_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count for pagination
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
total = (
|
||||
diagnostics.invalid_queued_with_start
|
||||
+ diagnostics.invalid_running_without_start
|
||||
)
|
||||
|
||||
return RunningExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/requeue",
|
||||
response_model=RequeueExecutionResponse,
|
||||
summary="Requeue Stuck Execution",
|
||||
)
|
||||
async def requeue_single_execution(
|
||||
request: StopExecutionRequest, # Reuse same request model (has execution_id)
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Requeue a stuck QUEUED execution (admin only).
|
||||
|
||||
Uses add_graph_execution with existing graph_exec_id to requeue.
|
||||
|
||||
⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.
|
||||
|
||||
Args:
|
||||
request: Contains execution_id to requeue
|
||||
|
||||
Returns:
|
||||
Success status and message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} requeueing execution {request.execution_id}")
|
||||
|
||||
# Get the execution (validation - must be QUEUED)
|
||||
executions = await get_graph_executions(
|
||||
graph_exec_id=request.execution_id,
|
||||
statuses=[AgentExecutionStatus.QUEUED],
|
||||
)
|
||||
|
||||
if not executions:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Execution not found or not in QUEUED status",
|
||||
)
|
||||
|
||||
execution = executions[0]
|
||||
|
||||
# Use add_graph_execution in requeue mode
|
||||
await add_graph_execution(
|
||||
graph_id=execution.graph_id,
|
||||
user_id=execution.user_id,
|
||||
graph_version=execution.graph_version,
|
||||
graph_exec_id=request.execution_id, # Requeue existing execution
|
||||
)
|
||||
|
||||
return RequeueExecutionResponse(
|
||||
success=True,
|
||||
requeued_count=1,
|
||||
message="Execution requeued successfully",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/requeue-bulk",
|
||||
response_model=RequeueExecutionResponse,
|
||||
summary="Requeue Multiple Stuck Executions",
|
||||
)
|
||||
async def requeue_multiple_executions(
|
||||
request: StopExecutionsRequest, # Reuse same request model (has execution_ids)
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Requeue multiple stuck QUEUED executions (admin only).
|
||||
|
||||
Uses add_graph_execution with existing graph_exec_id to requeue.
|
||||
|
||||
⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.
|
||||
|
||||
Args:
|
||||
request: Contains list of execution_ids to requeue
|
||||
|
||||
Returns:
|
||||
Number of executions requeued and success message
|
||||
"""
|
||||
logger.info(
|
||||
f"Admin {user.user_id} requeueing {len(request.execution_ids)} executions"
|
||||
)
|
||||
|
||||
# Get executions by ID list (must be QUEUED)
|
||||
executions = await get_graph_executions(
|
||||
execution_ids=request.execution_ids,
|
||||
statuses=[AgentExecutionStatus.QUEUED],
|
||||
)
|
||||
|
||||
if not executions:
|
||||
return RequeueExecutionResponse(
|
||||
success=False,
|
||||
requeued_count=0,
|
||||
message="No QUEUED executions found to requeue",
|
||||
)
|
||||
|
||||
# Requeue all executions in parallel using add_graph_execution
|
||||
async def requeue_one(exec) -> bool:
|
||||
try:
|
||||
await add_graph_execution(
|
||||
graph_id=exec.graph_id,
|
||||
user_id=exec.user_id,
|
||||
graph_version=exec.graph_version,
|
||||
graph_exec_id=exec.id, # Requeue existing
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to requeue {exec.id}: {e}")
|
||||
return False
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[requeue_one(exec) for exec in executions], return_exceptions=False
|
||||
)
|
||||
|
||||
requeued_count = sum(1 for success in results if success)
|
||||
|
||||
return RequeueExecutionResponse(
|
||||
success=requeued_count > 0,
|
||||
requeued_count=requeued_count,
|
||||
message=f"Requeued {requeued_count} of {len(request.execution_ids)} executions",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/stop",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Stop Single Execution",
|
||||
)
|
||||
async def stop_single_execution(
|
||||
request: StopExecutionRequest,
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Stop a single execution (admin only).
|
||||
|
||||
Uses robust stop_graph_execution which cascades to children and waits for termination.
|
||||
|
||||
Args:
|
||||
request: Contains execution_id to stop
|
||||
|
||||
Returns:
|
||||
Success status and message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} stopping execution {request.execution_id}")
|
||||
|
||||
# Get the execution to find its owner user_id (required by stop_graph_execution)
|
||||
executions = await get_graph_executions(
|
||||
graph_exec_id=request.execution_id,
|
||||
)
|
||||
|
||||
if not executions:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
execution = executions[0]
|
||||
|
||||
# Use robust stop_graph_execution (cascades to children, waits for termination)
|
||||
await stop_graph_execution(
|
||||
user_id=execution.user_id,
|
||||
graph_exec_id=request.execution_id,
|
||||
wait_timeout=15.0,
|
||||
cascade=True,
|
||||
)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=True,
|
||||
stopped_count=1,
|
||||
message="Execution stopped successfully",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/stop-bulk",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Stop Multiple Executions",
|
||||
)
|
||||
async def stop_multiple_executions(
|
||||
request: StopExecutionsRequest,
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Stop multiple active executions (admin only).
|
||||
|
||||
Uses robust stop_graph_execution which cascades to children and waits for termination.
|
||||
|
||||
Args:
|
||||
request: Contains list of execution_ids to stop
|
||||
|
||||
Returns:
|
||||
Number of executions stopped and success message
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
f"Admin {user.user_id} stopping {len(request.execution_ids)} executions"
|
||||
)
|
||||
|
||||
# Get executions by ID list
|
||||
executions = await get_graph_executions(
|
||||
execution_ids=request.execution_ids,
|
||||
)
|
||||
|
||||
if not executions:
|
||||
return StopExecutionResponse(
|
||||
success=False,
|
||||
stopped_count=0,
|
||||
message="No executions found",
|
||||
)
|
||||
|
||||
# Stop all executions in parallel using robust stop_graph_execution
|
||||
async def stop_one(exec) -> bool:
|
||||
try:
|
||||
await stop_graph_execution(
|
||||
user_id=exec.user_id,
|
||||
graph_exec_id=exec.id,
|
||||
wait_timeout=15.0,
|
||||
cascade=True,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop execution {exec.id}: {e}")
|
||||
return False
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[stop_one(exec) for exec in executions], return_exceptions=False
|
||||
)
|
||||
|
||||
stopped_count = sum(1 for success in results if success)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=stopped_count > 0,
|
||||
stopped_count=stopped_count,
|
||||
message=f"Stopped {stopped_count} of {len(request.execution_ids)} executions",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/cleanup-orphaned",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Cleanup Orphaned Executions",
|
||||
)
|
||||
async def cleanup_orphaned_executions(
|
||||
request: StopExecutionsRequest,
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Cleanup orphaned executions by directly updating DB status (admin only).
|
||||
For executions in DB but not actually running in executor (old/stale records).
|
||||
|
||||
Args:
|
||||
request: Contains list of execution_ids to cleanup
|
||||
|
||||
Returns:
|
||||
Number of executions cleaned up and success message
|
||||
"""
|
||||
logger.info(
|
||||
f"Admin {user.user_id} cleaning up {len(request.execution_ids)} orphaned executions"
|
||||
)
|
||||
|
||||
cleaned_count = await cleanup_orphaned_executions_bulk(
|
||||
request.execution_ids, user.user_id
|
||||
)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=cleaned_count > 0,
|
||||
stopped_count=cleaned_count,
|
||||
message=f"Cleaned up {cleaned_count} of {len(request.execution_ids)} orphaned executions",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SCHEDULE DIAGNOSTICS ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SchedulesListResponse(BaseModel):
|
||||
"""Response model for list of schedules"""
|
||||
|
||||
schedules: List[ScheduleDetail]
|
||||
total: int
|
||||
|
||||
|
||||
class OrphanedSchedulesListResponse(BaseModel):
|
||||
"""Response model for list of orphaned schedules"""
|
||||
|
||||
schedules: List[OrphanedScheduleDetail]
|
||||
total: int
|
||||
|
||||
|
||||
class ScheduleCleanupRequest(BaseModel):
|
||||
"""Request model for cleaning up schedules"""
|
||||
|
||||
schedule_ids: List[str]
|
||||
|
||||
|
||||
class ScheduleCleanupResponse(BaseModel):
|
||||
"""Response model for schedule cleanup operations"""
|
||||
|
||||
success: bool
|
||||
deleted_count: int = 0
|
||||
message: str
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/schedules",
|
||||
response_model=ScheduleHealthMetrics,
|
||||
summary="Get Schedule Diagnostics",
|
||||
)
|
||||
async def get_schedule_diagnostics_endpoint():
|
||||
"""
|
||||
Get comprehensive diagnostic information about schedule health.
|
||||
|
||||
Returns schedule metrics including:
|
||||
- Total schedules (user vs system)
|
||||
- Orphaned schedules by category
|
||||
- Upcoming executions
|
||||
"""
|
||||
logger.info("Getting schedule diagnostics")
|
||||
|
||||
diagnostics = await get_schedule_health_metrics()
|
||||
|
||||
logger.info(
|
||||
f"Schedule diagnostics: total={diagnostics.total_schedules}, "
|
||||
f"user={diagnostics.user_schedules}, "
|
||||
f"orphaned={diagnostics.total_orphaned}"
|
||||
)
|
||||
|
||||
return diagnostics
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/schedules/all",
|
||||
response_model=SchedulesListResponse,
|
||||
summary="List All User Schedules",
|
||||
)
|
||||
async def list_all_schedules(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of all user schedules (excludes system monitoring jobs).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of schedules to return (default 100)
|
||||
offset: Number of schedules to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of schedules with details
|
||||
"""
|
||||
logger.info(f"Listing all schedules (limit={limit}, offset={offset})")
|
||||
|
||||
schedules = await get_all_schedules_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count
|
||||
diagnostics = await get_schedule_health_metrics()
|
||||
total = diagnostics.user_schedules
|
||||
|
||||
return SchedulesListResponse(schedules=schedules, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/schedules/orphaned",
|
||||
response_model=OrphanedSchedulesListResponse,
|
||||
summary="List Orphaned Schedules",
|
||||
)
|
||||
async def list_orphaned_schedules():
|
||||
"""
|
||||
Get detailed list of orphaned schedules with orphan reasons.
|
||||
|
||||
Returns:
|
||||
List of orphaned schedules categorized by orphan type
|
||||
"""
|
||||
logger.info("Listing orphaned schedules")
|
||||
|
||||
schedules = await get_orphaned_schedules_details()
|
||||
|
||||
return OrphanedSchedulesListResponse(schedules=schedules, total=len(schedules))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/schedules/cleanup-orphaned",
|
||||
response_model=ScheduleCleanupResponse,
|
||||
summary="Cleanup Orphaned Schedules",
|
||||
)
|
||||
async def cleanup_orphaned_schedules(
|
||||
request: ScheduleCleanupRequest,
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Cleanup orphaned schedules by deleting from scheduler (admin only).
|
||||
|
||||
Args:
|
||||
request: Contains list of schedule_ids to delete
|
||||
|
||||
Returns:
|
||||
Number of schedules deleted and success message
|
||||
"""
|
||||
logger.info(
|
||||
f"Admin {user.user_id} cleaning up {len(request.schedule_ids)} orphaned schedules"
|
||||
)
|
||||
|
||||
deleted_count = await cleanup_orphaned_schedules_bulk(
|
||||
request.schedule_ids, user.user_id
|
||||
)
|
||||
|
||||
return ScheduleCleanupResponse(
|
||||
success=deleted_count > 0,
|
||||
deleted_count=deleted_count,
|
||||
message=f"Deleted {deleted_count} of {len(request.schedule_ids)} orphaned schedules",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/stop-all-long-running",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Stop ALL Long-Running Executions",
|
||||
)
|
||||
async def stop_all_long_running_executions_endpoint(
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Stop ALL long-running executions (RUNNING >24h) by sending cancel signals (admin only).
|
||||
Operates on entire dataset, not limited to pagination.
|
||||
|
||||
Returns:
|
||||
Number of executions stopped and success message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} stopping ALL long-running executions")
|
||||
|
||||
stopped_count = await stop_all_long_running_executions(user.user_id)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=stopped_count > 0,
|
||||
stopped_count=stopped_count,
|
||||
message=f"Stopped {stopped_count} long-running executions",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/cleanup-all-orphaned",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Cleanup ALL Orphaned Executions",
|
||||
)
|
||||
async def cleanup_all_orphaned_executions(
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Cleanup ALL orphaned executions (>24h old) by directly updating DB status.
|
||||
Operates on all executions, not just paginated results.
|
||||
|
||||
Returns:
|
||||
Number of executions cleaned up and success message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} cleaning up ALL orphaned executions")
|
||||
|
||||
# Fetch all orphaned execution IDs
|
||||
execution_ids = await get_all_orphaned_execution_ids()
|
||||
|
||||
if not execution_ids:
|
||||
return StopExecutionResponse(
|
||||
success=True,
|
||||
stopped_count=0,
|
||||
message="No orphaned executions to cleanup",
|
||||
)
|
||||
|
||||
cleaned_count = await cleanup_orphaned_executions_bulk(execution_ids, user.user_id)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=cleaned_count > 0,
|
||||
stopped_count=cleaned_count,
|
||||
message=f"Cleaned up {cleaned_count} orphaned executions",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/cleanup-all-stuck-queued",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Cleanup ALL Stuck Queued Executions",
|
||||
)
|
||||
async def cleanup_all_stuck_queued_executions_endpoint(
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Cleanup ALL stuck queued executions (QUEUED >1h) by updating DB status (admin only).
|
||||
Operates on entire dataset, not limited to pagination.
|
||||
|
||||
Returns:
|
||||
Number of executions cleaned up and success message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} cleaning up ALL stuck queued executions")
|
||||
|
||||
cleaned_count = await cleanup_all_stuck_queued_executions(user.user_id)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=cleaned_count > 0,
|
||||
stopped_count=cleaned_count,
|
||||
message=f"Cleaned up {cleaned_count} stuck queued executions",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/requeue-all-stuck",
|
||||
response_model=RequeueExecutionResponse,
|
||||
summary="Requeue ALL Stuck Queued Executions",
|
||||
)
|
||||
async def requeue_all_stuck_executions(
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Requeue ALL stuck queued executions (QUEUED >1h) by publishing to RabbitMQ.
|
||||
Operates on all executions, not just paginated results.
|
||||
|
||||
Uses add_graph_execution with existing graph_exec_id to requeue.
|
||||
|
||||
⚠️ WARNING: This will re-execute ALL stuck executions and may cost significant credits.
|
||||
|
||||
Returns:
|
||||
Number of executions requeued and success message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} requeueing ALL stuck queued executions")
|
||||
|
||||
# Fetch all stuck queued execution IDs
|
||||
execution_ids = await get_all_stuck_queued_execution_ids()
|
||||
|
||||
if not execution_ids:
|
||||
return RequeueExecutionResponse(
|
||||
success=True,
|
||||
requeued_count=0,
|
||||
message="No stuck queued executions to requeue",
|
||||
)
|
||||
|
||||
# Get stuck executions by ID list (must be QUEUED)
|
||||
executions = await get_graph_executions(
|
||||
execution_ids=execution_ids,
|
||||
statuses=[AgentExecutionStatus.QUEUED],
|
||||
)
|
||||
|
||||
# Requeue all in parallel using add_graph_execution
|
||||
async def requeue_one(exec) -> bool:
|
||||
try:
|
||||
await add_graph_execution(
|
||||
graph_id=exec.graph_id,
|
||||
user_id=exec.user_id,
|
||||
graph_version=exec.graph_version,
|
||||
graph_exec_id=exec.id, # Requeue existing
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to requeue {exec.id}: {e}")
|
||||
return False
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[requeue_one(exec) for exec in executions], return_exceptions=False
|
||||
)
|
||||
|
||||
requeued_count = sum(1 for success in results if success)
|
||||
|
||||
return RequeueExecutionResponse(
|
||||
success=requeued_count > 0,
|
||||
requeued_count=requeued_count,
|
||||
message=f"Requeued {requeued_count} stuck executions",
|
||||
)
|
||||
@@ -1,889 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
|
||||
import backend.api.features.admin.diagnostics_admin_routes as diagnostics_admin_routes
|
||||
from backend.data.diagnostics import (
|
||||
AgentDiagnosticsSummary,
|
||||
ExecutionDiagnosticsSummary,
|
||||
FailedExecutionDetail,
|
||||
OrphanedScheduleDetail,
|
||||
RunningExecutionDetail,
|
||||
ScheduleDetail,
|
||||
ScheduleHealthMetrics,
|
||||
)
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(diagnostics_admin_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_get_execution_diagnostics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
"""Test fetching execution diagnostics with invalid state detection"""
|
||||
mock_diagnostics = ExecutionDiagnosticsSummary(
|
||||
running_count=10,
|
||||
queued_db_count=5,
|
||||
rabbitmq_queue_depth=3,
|
||||
cancel_queue_depth=0,
|
||||
orphaned_running=2,
|
||||
orphaned_queued=1,
|
||||
failed_count_1h=5,
|
||||
failed_count_24h=20,
|
||||
failure_rate_24h=0.83,
|
||||
stuck_running_24h=1,
|
||||
stuck_running_1h=3,
|
||||
oldest_running_hours=26.5,
|
||||
stuck_queued_1h=2,
|
||||
queued_never_started=1,
|
||||
invalid_queued_with_start=1, # New invalid state
|
||||
invalid_running_without_start=1, # New invalid state
|
||||
completed_1h=50,
|
||||
completed_24h=1200,
|
||||
throughput_per_hour=50.0,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=mock_diagnostics,
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/executions")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify new invalid state fields are included
|
||||
assert data["invalid_queued_with_start"] == 1
|
||||
assert data["invalid_running_without_start"] == 1
|
||||
# Verify all expected fields present
|
||||
assert "running_executions" in data
|
||||
assert "orphaned_running" in data
|
||||
assert "failed_count_24h" in data
|
||||
|
||||
|
||||
def test_list_invalid_executions(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
"""Test listing executions in invalid states (read-only endpoint)"""
|
||||
mock_invalid_executions = [
|
||||
RunningExecutionDetail(
|
||||
execution_id="exec-invalid-1",
|
||||
graph_id="graph-123",
|
||||
graph_name="Test Graph",
|
||||
graph_version=1,
|
||||
user_id="user-123",
|
||||
user_email="test@example.com",
|
||||
status="QUEUED",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
started_at=datetime.now(
|
||||
timezone.utc
|
||||
), # QUEUED but has startedAt - INVALID!
|
||||
queue_status=None,
|
||||
),
|
||||
RunningExecutionDetail(
|
||||
execution_id="exec-invalid-2",
|
||||
graph_id="graph-456",
|
||||
graph_name="Another Graph",
|
||||
graph_version=2,
|
||||
user_id="user-456",
|
||||
user_email="user@example.com",
|
||||
status="RUNNING",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
started_at=None, # RUNNING but no startedAt - INVALID!
|
||||
queue_status=None,
|
||||
),
|
||||
]
|
||||
|
||||
mock_diagnostics = ExecutionDiagnosticsSummary(
|
||||
running_count=10,
|
||||
queued_db_count=5,
|
||||
rabbitmq_queue_depth=3,
|
||||
cancel_queue_depth=0,
|
||||
orphaned_running=0,
|
||||
orphaned_queued=0,
|
||||
failed_count_1h=0,
|
||||
failed_count_24h=0,
|
||||
failure_rate_24h=0.0,
|
||||
stuck_running_24h=0,
|
||||
stuck_running_1h=0,
|
||||
oldest_running_hours=None,
|
||||
stuck_queued_1h=0,
|
||||
queued_never_started=0,
|
||||
invalid_queued_with_start=1,
|
||||
invalid_running_without_start=1,
|
||||
completed_1h=0,
|
||||
completed_24h=0,
|
||||
throughput_per_hour=0.0,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_invalid_executions_details",
|
||||
return_value=mock_invalid_executions,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=mock_diagnostics,
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/executions/invalid?limit=100&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 2 # Sum of both invalid state types
|
||||
assert len(data["executions"]) == 2
|
||||
# Verify both types of invalid states are returned
|
||||
assert data["executions"][0]["execution_id"] in [
|
||||
"exec-invalid-1",
|
||||
"exec-invalid-2",
|
||||
]
|
||||
assert data["executions"][1]["execution_id"] in [
|
||||
"exec-invalid-1",
|
||||
"exec-invalid-2",
|
||||
]
|
||||
|
||||
|
||||
def test_requeue_single_execution_with_add_graph_execution(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
admin_user_id: str,
|
||||
):
|
||||
"""Test requeueing uses add_graph_execution in requeue mode"""
|
||||
mock_exec_meta = GraphExecutionMeta(
|
||||
id="exec-stuck-123",
|
||||
user_id="user-123",
|
||||
graph_id="graph-456",
|
||||
graph_version=1,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=AgentExecutionStatus.QUEUED,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[mock_exec_meta],
|
||||
)
|
||||
|
||||
mock_add_graph_execution = mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
|
||||
return_value=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/requeue",
|
||||
json={"execution_id": "exec-stuck-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["requeued_count"] == 1
|
||||
|
||||
# Verify it used add_graph_execution in requeue mode
|
||||
mock_add_graph_execution.assert_called_once()
|
||||
call_kwargs = mock_add_graph_execution.call_args.kwargs
|
||||
assert call_kwargs["graph_exec_id"] == "exec-stuck-123" # Requeue mode!
|
||||
assert call_kwargs["graph_id"] == "graph-456"
|
||||
assert call_kwargs["user_id"] == "user-123"
|
||||
|
||||
|
||||
def test_stop_single_execution_with_stop_graph_execution(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
admin_user_id: str,
|
||||
):
|
||||
"""Test stopping uses robust stop_graph_execution"""
|
||||
mock_exec_meta = GraphExecutionMeta(
|
||||
id="exec-running-123",
|
||||
user_id="user-789",
|
||||
graph_id="graph-999",
|
||||
graph_version=2,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=AgentExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[mock_exec_meta],
|
||||
)
|
||||
|
||||
mock_stop_graph_execution = mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
|
||||
return_value=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/stop",
|
||||
json={"execution_id": "exec-running-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 1
|
||||
|
||||
# Verify it used stop_graph_execution with cascade
|
||||
mock_stop_graph_execution.assert_called_once()
|
||||
call_kwargs = mock_stop_graph_execution.call_args.kwargs
|
||||
assert call_kwargs["graph_exec_id"] == "exec-running-123"
|
||||
assert call_kwargs["user_id"] == "user-789"
|
||||
assert call_kwargs["cascade"] is True # Stops children too!
|
||||
assert call_kwargs["wait_timeout"] == 15.0
|
||||
|
||||
|
||||
def test_requeue_not_queued_execution_fails(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
"""Test that requeue fails if execution is not in QUEUED status"""
|
||||
# Mock an execution that's RUNNING (not QUEUED)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[], # No QUEUED executions found
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/requeue",
|
||||
json={"execution_id": "exec-running-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found or not in QUEUED status" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_list_invalid_executions_no_bulk_actions(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
"""Verify invalid executions endpoint is read-only (no bulk actions)"""
|
||||
# This is a documentation test - the endpoint exists but should not
|
||||
# have corresponding cleanup/stop/requeue endpoints
|
||||
|
||||
# These endpoints should NOT exist for invalid states:
|
||||
invalid_bulk_endpoints = [
|
||||
"/admin/diagnostics/executions/cleanup-invalid",
|
||||
"/admin/diagnostics/executions/stop-invalid",
|
||||
"/admin/diagnostics/executions/requeue-invalid",
|
||||
]
|
||||
|
||||
for endpoint in invalid_bulk_endpoints:
|
||||
response = client.post(endpoint, json={"execution_ids": ["test"]})
|
||||
assert response.status_code == 404, f"{endpoint} should not exist (read-only)"
|
||||
|
||||
|
||||
def test_execution_ids_filter_efficiency(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
"""Test that bulk operations use efficient execution_ids filter"""
|
||||
mock_exec_metas = [
|
||||
GraphExecutionMeta(
|
||||
id=f"exec-{i}",
|
||||
user_id=f"user-{i}",
|
||||
graph_id="graph-123",
|
||||
graph_version=1,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=AgentExecutionStatus.QUEUED,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
mock_get_graph_executions = mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=mock_exec_metas,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
|
||||
return_value=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/requeue-bulk",
|
||||
json={"execution_ids": ["exec-0", "exec-1", "exec-2"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify it used execution_ids filter (not fetching all queued)
|
||||
mock_get_graph_executions.assert_called_once()
|
||||
call_kwargs = mock_get_graph_executions.call_args.kwargs
|
||||
assert "execution_ids" in call_kwargs
|
||||
assert call_kwargs["execution_ids"] == ["exec-0", "exec-1", "exec-2"]
|
||||
assert call_kwargs["statuses"] == [AgentExecutionStatus.QUEUED]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: reusable mock diagnostics summary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_diagnostics(**overrides) -> ExecutionDiagnosticsSummary:
|
||||
defaults = dict(
|
||||
running_count=10,
|
||||
queued_db_count=5,
|
||||
rabbitmq_queue_depth=3,
|
||||
cancel_queue_depth=0,
|
||||
orphaned_running=2,
|
||||
orphaned_queued=1,
|
||||
failed_count_1h=5,
|
||||
failed_count_24h=20,
|
||||
failure_rate_24h=0.83,
|
||||
stuck_running_24h=3,
|
||||
stuck_running_1h=5,
|
||||
oldest_running_hours=26.5,
|
||||
stuck_queued_1h=2,
|
||||
queued_never_started=1,
|
||||
invalid_queued_with_start=1,
|
||||
invalid_running_without_start=1,
|
||||
completed_1h=50,
|
||||
completed_24h=1200,
|
||||
throughput_per_hour=50.0,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return ExecutionDiagnosticsSummary(**defaults)
|
||||
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
def _make_mock_execution(
|
||||
exec_id: str = "exec-1",
|
||||
status: str = "RUNNING",
|
||||
started_at: datetime | None | object = _SENTINEL,
|
||||
) -> RunningExecutionDetail:
|
||||
return RunningExecutionDetail(
|
||||
execution_id=exec_id,
|
||||
graph_id="graph-123",
|
||||
graph_name="Test Graph",
|
||||
graph_version=1,
|
||||
user_id="user-123",
|
||||
user_email="test@example.com",
|
||||
status=status,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
started_at=(
|
||||
datetime.now(timezone.utc) if started_at is _SENTINEL else started_at
|
||||
),
|
||||
queue_status=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_failed_execution(
|
||||
exec_id: str = "exec-fail-1",
|
||||
) -> FailedExecutionDetail:
|
||||
return FailedExecutionDetail(
|
||||
execution_id=exec_id,
|
||||
graph_id="graph-123",
|
||||
graph_name="Test Graph",
|
||||
graph_version=1,
|
||||
user_id="user-123",
|
||||
user_email="test@example.com",
|
||||
status="FAILED",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
started_at=datetime.now(timezone.utc),
|
||||
failed_at=datetime.now(timezone.utc),
|
||||
error_message="Something went wrong",
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_schedule_health(**overrides) -> ScheduleHealthMetrics:
|
||||
defaults = dict(
|
||||
total_schedules=15,
|
||||
user_schedules=10,
|
||||
system_schedules=5,
|
||||
orphaned_deleted_graph=2,
|
||||
orphaned_no_library_access=1,
|
||||
orphaned_invalid_credentials=0,
|
||||
orphaned_validation_failed=0,
|
||||
total_orphaned=3,
|
||||
schedules_next_hour=4,
|
||||
schedules_next_24h=8,
|
||||
total_runs_next_hour=12,
|
||||
total_runs_next_24h=48,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return ScheduleHealthMetrics(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET endpoints: execution list variants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_list_running_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_execs = [
|
||||
_make_mock_execution("exec-run-1"),
|
||||
_make_mock_execution("exec-run-2"),
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_running_executions_details",
|
||||
return_value=mock_execs,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=_make_mock_diagnostics(),
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/executions/running?limit=50&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 15 # running_count(10) + queued_db_count(5)
|
||||
assert len(data["executions"]) == 2
|
||||
assert data["executions"][0]["execution_id"] == "exec-run-1"
|
||||
|
||||
|
||||
def test_list_orphaned_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_execs = [_make_mock_execution("exec-orphan-1", status="RUNNING")]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_orphaned_executions_details",
|
||||
return_value=mock_execs,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=_make_mock_diagnostics(),
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/executions/orphaned?limit=50&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 3 # orphaned_running(2) + orphaned_queued(1)
|
||||
assert len(data["executions"]) == 1
|
||||
|
||||
|
||||
def test_list_failed_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_execs = [_make_mock_failed_execution("exec-fail-1")]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_details",
|
||||
return_value=mock_execs,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_count",
|
||||
return_value=42,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/diagnostics/executions/failed?limit=50&offset=0&hours=24"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 42
|
||||
assert len(data["executions"]) == 1
|
||||
assert data["executions"][0]["error_message"] == "Something went wrong"
|
||||
|
||||
|
||||
def test_list_long_running_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_execs = [_make_mock_execution("exec-long-1")]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_long_running_executions_details",
|
||||
return_value=mock_execs,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=_make_mock_diagnostics(),
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/diagnostics/executions/long-running?limit=50&offset=0"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 3 # stuck_running_24h
|
||||
assert len(data["executions"]) == 1
|
||||
|
||||
|
||||
def test_list_stuck_queued_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_execs = [
|
||||
_make_mock_execution("exec-stuck-1", status="QUEUED", started_at=None)
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_stuck_queued_executions_details",
|
||||
return_value=mock_execs,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=_make_mock_diagnostics(),
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/diagnostics/executions/stuck-queued?limit=50&offset=0"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 2 # stuck_queued_1h
|
||||
assert len(data["executions"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET endpoints: agent + schedule diagnostics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_agent_diagnostics(mocker: pytest_mock.MockFixture):
|
||||
mock_diag = AgentDiagnosticsSummary(
|
||||
agents_with_active_executions=7,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_agent_diagnostics",
|
||||
return_value=mock_diag,
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/agents")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["agents_with_active_executions"] == 7
|
||||
|
||||
|
||||
def test_get_schedule_diagnostics(mocker: pytest_mock.MockFixture):
|
||||
mock_metrics = _make_mock_schedule_health()
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
|
||||
return_value=mock_metrics,
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/schedules")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_schedules"] == 10
|
||||
assert data["total_orphaned"] == 3
|
||||
assert data["total_runs_next_hour"] == 12
|
||||
|
||||
|
||||
def test_list_all_schedules(mocker: pytest_mock.MockFixture):
|
||||
mock_schedules = [
|
||||
ScheduleDetail(
|
||||
schedule_id="sched-1",
|
||||
schedule_name="Daily Run",
|
||||
graph_id="graph-1",
|
||||
graph_name="My Agent",
|
||||
graph_version=1,
|
||||
user_id="user-1",
|
||||
user_email="alice@example.com",
|
||||
cron="0 9 * * *",
|
||||
timezone="UTC",
|
||||
next_run_time=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_all_schedules_details",
|
||||
return_value=mock_schedules,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
|
||||
return_value=_make_mock_schedule_health(),
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/schedules/all?limit=50&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 10
|
||||
assert len(data["schedules"]) == 1
|
||||
assert data["schedules"][0]["schedule_name"] == "Daily Run"
|
||||
|
||||
|
||||
def test_list_orphaned_schedules(mocker: pytest_mock.MockFixture):
|
||||
mock_orphans = [
|
||||
OrphanedScheduleDetail(
|
||||
schedule_id="sched-orphan-1",
|
||||
schedule_name="Ghost Schedule",
|
||||
graph_id="graph-deleted",
|
||||
graph_version=1,
|
||||
user_id="user-1",
|
||||
orphan_reason="deleted_graph",
|
||||
error_detail=None,
|
||||
next_run_time=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_orphaned_schedules_details",
|
||||
return_value=mock_orphans,
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/schedules/orphaned")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 1
|
||||
assert data["schedules"][0]["orphan_reason"] == "deleted_graph"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST endpoints: bulk stop, cleanup, requeue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stop_multiple_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_exec_metas = [
|
||||
GraphExecutionMeta(
|
||||
id=f"exec-{i}",
|
||||
user_id=f"user-{i}",
|
||||
graph_id="graph-123",
|
||||
graph_version=1,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=AgentExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=None,
|
||||
stats=None,
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=mock_exec_metas,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
|
||||
return_value=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/stop-bulk",
|
||||
json={"execution_ids": ["exec-0", "exec-1"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 2
|
||||
|
||||
|
||||
def test_stop_multiple_executions_none_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/stop-bulk",
|
||||
json={"execution_ids": ["nonexistent"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
assert data["stopped_count"] == 0
|
||||
|
||||
|
||||
def test_cleanup_orphaned_executions(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
|
||||
return_value=3,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/cleanup-orphaned",
|
||||
json={"execution_ids": ["exec-1", "exec-2", "exec-3"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 3
|
||||
|
||||
|
||||
def test_cleanup_orphaned_schedules(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_schedules_bulk",
|
||||
return_value=2,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/schedules/cleanup-orphaned",
|
||||
json={"schedule_ids": ["sched-1", "sched-2"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["deleted_count"] == 2
|
||||
|
||||
|
||||
def test_stop_all_long_running_executions(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.stop_all_long_running_executions",
|
||||
return_value=5,
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/stop-all-long-running")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 5
|
||||
|
||||
|
||||
def test_cleanup_all_orphaned_executions(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
|
||||
return_value=["exec-1", "exec-2"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
|
||||
return_value=2,
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 2
|
||||
|
||||
|
||||
def test_cleanup_all_orphaned_executions_none(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 0
|
||||
assert "No orphaned" in data["message"]
|
||||
|
||||
|
||||
def test_cleanup_all_stuck_queued_executions(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.cleanup_all_stuck_queued_executions",
|
||||
return_value=4,
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/cleanup-all-stuck-queued")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 4
|
||||
|
||||
|
||||
def test_requeue_all_stuck_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_exec_metas = [
|
||||
GraphExecutionMeta(
|
||||
id=f"exec-stuck-{i}",
|
||||
user_id=f"user-{i}",
|
||||
graph_id="graph-123",
|
||||
graph_version=1,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=AgentExecutionStatus.QUEUED,
|
||||
started_at=None,
|
||||
ended_at=None,
|
||||
stats=None,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
|
||||
return_value=["exec-stuck-0", "exec-stuck-1", "exec-stuck-2"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=mock_exec_metas,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
|
||||
return_value=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/requeue-all-stuck")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["requeued_count"] == 3
|
||||
|
||||
|
||||
def test_requeue_all_stuck_executions_none(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/requeue-all-stuck")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["requeued_count"] == 0
|
||||
assert "No stuck" in data["message"]
|
||||
|
||||
|
||||
def test_requeue_bulk_none_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/requeue-bulk",
|
||||
json={"execution_ids": ["nonexistent"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
assert data["requeued_count"] == 0
|
||||
|
||||
|
||||
def test_stop_single_execution_not_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/stop",
|
||||
json={"execution_id": "nonexistent"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"]
|
||||
@@ -14,70 +14,3 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class ExecutionDiagnosticsResponse(BaseModel):
|
||||
"""Response model for execution diagnostics"""
|
||||
|
||||
# Current execution state
|
||||
running_executions: int
|
||||
queued_executions_db: int
|
||||
queued_executions_rabbitmq: int
|
||||
cancel_queue_depth: int
|
||||
|
||||
# Orphaned execution detection
|
||||
orphaned_running: int
|
||||
orphaned_queued: int
|
||||
|
||||
# Failure metrics
|
||||
failed_count_1h: int
|
||||
failed_count_24h: int
|
||||
failure_rate_24h: float
|
||||
|
||||
# Long-running detection
|
||||
stuck_running_24h: int
|
||||
stuck_running_1h: int
|
||||
oldest_running_hours: float | None
|
||||
|
||||
# Stuck queued detection
|
||||
stuck_queued_1h: int
|
||||
queued_never_started: int
|
||||
|
||||
# Invalid state detection (data corruption - no auto-actions)
|
||||
invalid_queued_with_start: int
|
||||
invalid_running_without_start: int
|
||||
|
||||
# Throughput metrics
|
||||
completed_1h: int
|
||||
completed_24h: int
|
||||
throughput_per_hour: float
|
||||
|
||||
timestamp: str
|
||||
|
||||
|
||||
class AgentDiagnosticsResponse(BaseModel):
|
||||
"""Response model for agent diagnostics"""
|
||||
|
||||
agents_with_active_executions: int
|
||||
timestamp: str
|
||||
|
||||
|
||||
class ScheduleHealthMetrics(BaseModel):
|
||||
"""Response model for schedule diagnostics"""
|
||||
|
||||
total_schedules: int
|
||||
user_schedules: int
|
||||
system_schedules: int
|
||||
|
||||
# Orphan detection
|
||||
orphaned_deleted_graph: int
|
||||
orphaned_no_library_access: int
|
||||
orphaned_invalid_credentials: int
|
||||
orphaned_validation_failed: int
|
||||
total_orphaned: int
|
||||
|
||||
# Upcoming
|
||||
schedules_next_hour: int
|
||||
schedules_next_24h: int
|
||||
|
||||
timestamp: str
|
||||
|
||||
@@ -1,141 +0,0 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, Query, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.platform_cost import (
|
||||
CostLogRow,
|
||||
PlatformCostDashboard,
|
||||
get_platform_cost_dashboard,
|
||||
get_platform_cost_logs,
|
||||
get_platform_cost_logs_for_export,
|
||||
)
|
||||
from backend.util.models import Pagination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/platform-costs",
|
||||
tags=["platform-cost", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
class PlatformCostLogsResponse(BaseModel):
|
||||
logs: list[CostLogRow]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
@router.get(
|
||||
"/dashboard",
|
||||
response_model=PlatformCostDashboard,
|
||||
summary="Get Platform Cost Dashboard",
|
||||
)
|
||||
async def get_cost_dashboard(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
start: datetime | None = Query(None),
|
||||
end: datetime | None = Query(None),
|
||||
provider: str | None = Query(None),
|
||||
user_id: str | None = Query(None),
|
||||
model: str | None = Query(None),
|
||||
block_name: str | None = Query(None),
|
||||
tracking_type: str | None = Query(None),
|
||||
graph_exec_id: str | None = Query(None),
|
||||
):
|
||||
logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
|
||||
return await get_platform_cost_dashboard(
|
||||
start=start,
|
||||
end=end,
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
model=model,
|
||||
block_name=block_name,
|
||||
tracking_type=tracking_type,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/logs",
|
||||
response_model=PlatformCostLogsResponse,
|
||||
summary="Get Platform Cost Logs",
|
||||
)
|
||||
async def get_cost_logs(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
start: datetime | None = Query(None),
|
||||
end: datetime | None = Query(None),
|
||||
provider: str | None = Query(None),
|
||||
user_id: str | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
model: str | None = Query(None),
|
||||
block_name: str | None = Query(None),
|
||||
tracking_type: str | None = Query(None),
|
||||
graph_exec_id: str | None = Query(None),
|
||||
):
|
||||
logger.info("Admin %s fetching platform cost logs", admin_user_id)
|
||||
logs, total = await get_platform_cost_logs(
|
||||
start=start,
|
||||
end=end,
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
model=model,
|
||||
block_name=block_name,
|
||||
tracking_type=tracking_type,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
return PlatformCostLogsResponse(
|
||||
logs=logs,
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PlatformCostExportResponse(BaseModel):
|
||||
logs: list[CostLogRow]
|
||||
total_rows: int
|
||||
truncated: bool
|
||||
|
||||
|
||||
@router.get(
|
||||
"/logs/export",
|
||||
response_model=PlatformCostExportResponse,
|
||||
summary="Export Platform Cost Logs",
|
||||
)
|
||||
async def export_cost_logs(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
start: datetime | None = Query(None),
|
||||
end: datetime | None = Query(None),
|
||||
provider: str | None = Query(None),
|
||||
user_id: str | None = Query(None),
|
||||
model: str | None = Query(None),
|
||||
block_name: str | None = Query(None),
|
||||
tracking_type: str | None = Query(None),
|
||||
graph_exec_id: str | None = Query(None),
|
||||
):
|
||||
logger.info("Admin %s exporting platform cost logs", admin_user_id)
|
||||
logs, truncated = await get_platform_cost_logs_for_export(
|
||||
start=start,
|
||||
end=end,
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
model=model,
|
||||
block_name=block_name,
|
||||
tracking_type=tracking_type,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
return PlatformCostExportResponse(
|
||||
logs=logs,
|
||||
total_rows=len(logs),
|
||||
truncated=truncated,
|
||||
)
|
||||
@@ -1,291 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.platform_cost import CostLogRow, PlatformCostDashboard
|
||||
|
||||
from .platform_cost_routes import router as platform_cost_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(platform_cost_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_get_dashboard_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
real_dashboard = PlatformCostDashboard(
|
||||
by_provider=[],
|
||||
by_user=[],
|
||||
total_cost_microdollars=0,
|
||||
total_requests=0,
|
||||
total_users=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
|
||||
AsyncMock(return_value=real_dashboard),
|
||||
)
|
||||
|
||||
response = client.get("/platform-costs/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "by_provider" in data
|
||||
assert "by_user" in data
|
||||
assert data["total_cost_microdollars"] == 0
|
||||
|
||||
|
||||
def test_get_logs_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
|
||||
AsyncMock(return_value=([], 0)),
|
||||
)
|
||||
|
||||
response = client.get("/platform-costs/logs")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["logs"] == []
|
||||
assert data["pagination"]["total_items"] == 0
|
||||
|
||||
|
||||
def test_get_dashboard_with_filters(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
real_dashboard = PlatformCostDashboard(
|
||||
by_provider=[],
|
||||
by_user=[],
|
||||
total_cost_microdollars=0,
|
||||
total_requests=0,
|
||||
total_users=0,
|
||||
)
|
||||
mock_dashboard = AsyncMock(return_value=real_dashboard)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
|
||||
mock_dashboard,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/platform-costs/dashboard",
|
||||
params={
|
||||
"start": "2026-01-01T00:00:00",
|
||||
"end": "2026-04-01T00:00:00",
|
||||
"provider": "openai",
|
||||
"user_id": "test-user-123",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_dashboard.assert_called_once()
|
||||
call_kwargs = mock_dashboard.call_args.kwargs
|
||||
assert call_kwargs["provider"] == "openai"
|
||||
assert call_kwargs["user_id"] == "test-user-123"
|
||||
assert call_kwargs["start"] is not None
|
||||
assert call_kwargs["end"] is not None
|
||||
|
||||
|
||||
def test_get_logs_with_pagination(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
|
||||
AsyncMock(return_value=([], 0)),
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/platform-costs/logs",
|
||||
params={"page": 2, "page_size": 25, "provider": "anthropic"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["pagination"]["current_page"] == 2
|
||||
assert data["pagination"]["page_size"] == 25
|
||||
|
||||
|
||||
def test_get_dashboard_requires_admin() -> None:
|
||||
import fastapi
|
||||
from fastapi import HTTPException
|
||||
|
||||
def reject_jwt(request: fastapi.Request):
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = reject_jwt
|
||||
try:
|
||||
response = client.get("/platform-costs/dashboard")
|
||||
assert response.status_code == 401
|
||||
response = client.get("/platform-costs/logs")
|
||||
assert response.status_code == 401
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_get_dashboard_rejects_non_admin(mock_jwt_user, mock_jwt_admin) -> None:
|
||||
"""Non-admin JWT must be rejected with 403 by requires_admin_user."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
try:
|
||||
response = client.get("/platform-costs/dashboard")
|
||||
assert response.status_code == 403
|
||||
response = client.get("/platform-costs/logs")
|
||||
assert response.status_code == 403
|
||||
finally:
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
|
||||
|
||||
def test_get_logs_invalid_page_size_too_large() -> None:
|
||||
"""page_size > 200 must be rejected with 422."""
|
||||
response = client.get("/platform-costs/logs", params={"page_size": 201})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_logs_invalid_page_size_zero() -> None:
|
||||
"""page_size = 0 (below ge=1) must be rejected with 422."""
|
||||
response = client.get("/platform-costs/logs", params={"page_size": 0})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_logs_invalid_page_negative() -> None:
|
||||
"""page < 1 must be rejected with 422."""
|
||||
response = client.get("/platform-costs/logs", params={"page": 0})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_dashboard_invalid_date_format() -> None:
|
||||
"""Malformed start date must be rejected with 422."""
|
||||
response = client.get("/platform-costs/dashboard", params={"start": "not-a-date"})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_dashboard_repeated_requests(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Repeated requests to the dashboard route both return 200."""
|
||||
real_dashboard = PlatformCostDashboard(
|
||||
by_provider=[],
|
||||
by_user=[],
|
||||
total_cost_microdollars=42,
|
||||
total_requests=1,
|
||||
total_users=1,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
|
||||
AsyncMock(return_value=real_dashboard),
|
||||
)
|
||||
|
||||
r1 = client.get("/platform-costs/dashboard")
|
||||
r2 = client.get("/platform-costs/dashboard")
|
||||
|
||||
assert r1.status_code == 200
|
||||
assert r2.status_code == 200
|
||||
assert r1.json()["total_cost_microdollars"] == 42
|
||||
assert r2.json()["total_cost_microdollars"] == 42
|
||||
|
||||
|
||||
def _make_cost_log_row() -> CostLogRow:
|
||||
return CostLogRow(
|
||||
id="log-1",
|
||||
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
user_id="user-1",
|
||||
email="u***@example.com",
|
||||
graph_exec_id="graph-1",
|
||||
node_exec_id="node-1",
|
||||
block_name="LlmCallBlock",
|
||||
provider="anthropic",
|
||||
tracking_type="token",
|
||||
cost_microdollars=500,
|
||||
input_tokens=100,
|
||||
output_tokens=50,
|
||||
cache_read_tokens=10,
|
||||
cache_creation_tokens=5,
|
||||
duration=1.5,
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
)
|
||||
|
||||
|
||||
def test_export_logs_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
row = _make_cost_log_row()
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export",
|
||||
AsyncMock(return_value=([row], False)),
|
||||
)
|
||||
|
||||
response = client.get("/platform-costs/logs/export")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_rows"] == 1
|
||||
assert data["truncated"] is False
|
||||
assert len(data["logs"]) == 1
|
||||
assert data["logs"][0]["cache_read_tokens"] == 10
|
||||
assert data["logs"][0]["cache_creation_tokens"] == 5
|
||||
|
||||
|
||||
def test_export_logs_truncated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
rows = [_make_cost_log_row() for _ in range(3)]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export",
|
||||
AsyncMock(return_value=(rows, True)),
|
||||
)
|
||||
|
||||
response = client.get("/platform-costs/logs/export")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_rows"] == 3
|
||||
assert data["truncated"] is True
|
||||
|
||||
|
||||
def test_export_logs_with_filters(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mock_export = AsyncMock(return_value=([], False))
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export",
|
||||
mock_export,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/platform-costs/logs/export",
|
||||
params={
|
||||
"provider": "anthropic",
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"block_name": "LlmCallBlock",
|
||||
"tracking_type": "token",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_export.assert_called_once()
|
||||
call_kwargs = mock_export.call_args.kwargs
|
||||
assert call_kwargs["provider"] == "anthropic"
|
||||
assert call_kwargs["model"] == "claude-3-5-sonnet-20241022"
|
||||
assert call_kwargs["block_name"] == "LlmCallBlock"
|
||||
assert call_kwargs["tracking_type"] == "token"
|
||||
|
||||
|
||||
def test_export_logs_requires_admin() -> None:
|
||||
import fastapi
|
||||
from fastapi import HTTPException
|
||||
|
||||
def reject_jwt(request: fastapi.Request):
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = reject_jwt
|
||||
try:
|
||||
response = client.get("/platform-costs/logs/export")
|
||||
assert response.status_code == 401
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
@@ -1,263 +0,0 @@
|
||||
"""Admin endpoints for checking and resetting user CoPilot rate limit usage."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.rate_limit import (
|
||||
SubscriptionTier,
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
get_user_tier,
|
||||
reset_user_usage,
|
||||
set_user_tier,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_email_by_id, search_users
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["copilot", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
class UserRateLimitResponse(BaseModel):
|
||||
user_id: str
|
||||
user_email: Optional[str] = None
|
||||
daily_cost_limit_microdollars: int
|
||||
weekly_cost_limit_microdollars: int
|
||||
daily_cost_used_microdollars: int
|
||||
weekly_cost_used_microdollars: int
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
class UserTierResponse(BaseModel):
|
||||
user_id: str
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
class SetUserTierRequest(BaseModel):
|
||||
user_id: str
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
async def _resolve_user_id(
|
||||
user_id: Optional[str], email: Optional[str]
|
||||
) -> tuple[str, Optional[str]]:
|
||||
"""Resolve a user_id and email from the provided parameters.
|
||||
|
||||
Returns (user_id, email). Accepts either user_id or email; at least one
|
||||
must be provided. When both are provided, ``email`` takes precedence.
|
||||
"""
|
||||
if email:
|
||||
user = await get_user_by_email(email)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No user found with the provided email."
|
||||
)
|
||||
return user.id, email
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either user_id or email query parameter is required.",
|
||||
)
|
||||
|
||||
# We have a user_id; try to look up their email for display purposes.
|
||||
# This is non-critical -- a failure should not block the response.
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
|
||||
resolved_email = None
|
||||
return user_id, resolved_email
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit",
|
||||
response_model=UserRateLimitResponse,
|
||||
summary="Get User Rate Limit",
|
||||
)
|
||||
async def get_user_rate_limit(
|
||||
user_id: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserRateLimitResponse:
|
||||
"""Get a user's current usage and effective rate limits. Admin-only.
|
||||
|
||||
Accepts either ``user_id`` or ``email`` as a query parameter.
|
||||
When ``email`` is provided the user is looked up by email first.
|
||||
"""
|
||||
resolved_id, resolved_email = await _resolve_user_id(user_id, email)
|
||||
|
||||
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
|
||||
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
resolved_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
)
|
||||
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=resolved_id,
|
||||
user_email=resolved_email,
|
||||
daily_cost_limit_microdollars=daily_limit,
|
||||
weekly_cost_limit_microdollars=weekly_limit,
|
||||
daily_cost_used_microdollars=usage.daily.used,
|
||||
weekly_cost_used_microdollars=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rate_limit/reset",
|
||||
response_model=UserRateLimitResponse,
|
||||
summary="Reset User Rate Limit Usage",
|
||||
)
|
||||
async def reset_user_rate_limit(
|
||||
user_id: str = Body(embed=True),
|
||||
reset_weekly: bool = Body(False, embed=True),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserRateLimitResponse:
|
||||
"""Reset a user's daily usage counter (and optionally weekly). Admin-only."""
|
||||
logger.info(
|
||||
"Admin %s resetting rate limit for user %s (reset_weekly=%s)",
|
||||
admin_user_id,
|
||||
user_id,
|
||||
reset_weekly,
|
||||
)
|
||||
|
||||
try:
|
||||
await reset_user_usage(user_id, reset_weekly=reset_weekly)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to reset user usage")
|
||||
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
|
||||
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
|
||||
resolved_email = None
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=user_id,
|
||||
user_email=resolved_email,
|
||||
daily_cost_limit_microdollars=daily_limit,
|
||||
weekly_cost_limit_microdollars=weekly_limit,
|
||||
daily_cost_used_microdollars=usage.daily.used,
|
||||
weekly_cost_used_microdollars=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit/tier",
|
||||
response_model=UserTierResponse,
|
||||
summary="Get User Rate Limit Tier",
|
||||
)
|
||||
async def get_user_rate_limit_tier(
|
||||
user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserTierResponse:
|
||||
"""Get a user's current rate-limit tier. Admin-only.
|
||||
|
||||
Returns 404 if the user does not exist in the database.
|
||||
"""
|
||||
logger.info("Admin %s checking tier for user %s", admin_user_id, user_id)
|
||||
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
if resolved_email is None:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
|
||||
|
||||
tier = await get_user_tier(user_id)
|
||||
return UserTierResponse(user_id=user_id, tier=tier)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rate_limit/tier",
|
||||
response_model=UserTierResponse,
|
||||
summary="Set User Rate Limit Tier",
|
||||
)
|
||||
async def set_user_rate_limit_tier(
|
||||
request: SetUserTierRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserTierResponse:
|
||||
"""Set a user's rate-limit tier. Admin-only.
|
||||
|
||||
Returns 404 if the user does not exist in the database.
|
||||
"""
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(request.user_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to resolve email for user %s",
|
||||
request.user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
resolved_email = None
|
||||
|
||||
if resolved_email is None:
|
||||
raise HTTPException(status_code=404, detail=f"User {request.user_id} not found")
|
||||
|
||||
old_tier = await get_user_tier(request.user_id)
|
||||
logger.info(
|
||||
"Admin %s changing tier for user %s (%s): %s -> %s",
|
||||
admin_user_id,
|
||||
request.user_id,
|
||||
resolved_email,
|
||||
old_tier.value,
|
||||
request.tier.value,
|
||||
)
|
||||
try:
|
||||
await set_user_tier(request.user_id, request.tier)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to set user tier")
|
||||
raise HTTPException(status_code=500, detail="Failed to set tier") from e
|
||||
|
||||
return UserTierResponse(user_id=request.user_id, tier=request.tier)
|
||||
|
||||
|
||||
class UserSearchResult(BaseModel):
|
||||
user_id: str
|
||||
user_email: Optional[str] = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit/search_users",
|
||||
response_model=list[UserSearchResult],
|
||||
summary="Search Users by Name or Email",
|
||||
)
|
||||
async def admin_search_users(
|
||||
query: str,
|
||||
limit: int = 20,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> list[UserSearchResult]:
|
||||
"""Search users by partial email or name. Admin-only.
|
||||
|
||||
Queries the User table directly — returns results even for users
|
||||
without credit transaction history.
|
||||
"""
|
||||
if len(query.strip()) < 3:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Search query must be at least 3 characters.",
|
||||
)
|
||||
logger.info("Admin %s searching users with query=%r", admin_user_id, query)
|
||||
results = await search_users(query, limit=max(1, min(limit, 50)))
|
||||
return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]
|
||||
@@ -1,566 +0,0 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
|
||||
|
||||
from .rate_limit_admin_routes import router as rate_limit_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(rate_limit_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
_MOCK_MODULE = "backend.api.features.admin.rate_limit_admin_routes"
|
||||
|
||||
_TARGET_EMAIL = "target@example.com"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _mock_usage_status(
|
||||
daily_used: int = 500_000, weekly_used: int = 3_000_000
|
||||
) -> CoPilotUsageStatus:
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
now = datetime.now(UTC)
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used, limit=2_500_000, resets_at=now + timedelta(hours=6)
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used, limit=12_500_000, resets_at=now + timedelta(days=3)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _patch_rate_limit_deps(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
daily_used: int = 500_000,
|
||||
weekly_used: int = 3_000_000,
|
||||
):
|
||||
"""Patch the common rate-limit + user-lookup dependencies."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(daily_used=daily_used, weekly_used=weekly_used),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
|
||||
|
||||
def test_get_rate_limit(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting rate limit and usage for a user."""
|
||||
_patch_rate_limit_deps(mocker, target_user_id)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_cost_limit_microdollars"] == 2_500_000
|
||||
assert data["weekly_cost_limit_microdollars"] == 12_500_000
|
||||
assert data["daily_cost_used_microdollars"] == 500_000
|
||||
assert data["weekly_cost_used_microdollars"] == 3_000_000
|
||||
assert data["tier"] == "BASIC"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"get_rate_limit",
|
||||
)
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test looking up rate limits via email instead of user_id."""
|
||||
_patch_rate_limit_deps(mocker, target_user_id)
|
||||
|
||||
mock_user = SimpleNamespace(id=target_user_id, email=_TARGET_EMAIL)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"email": _TARGET_EMAIL})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_cost_limit_microdollars"] == 2_500_000
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Test that looking up a non-existent email returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"email": "nobody@example.com"})
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_get_rate_limit_no_params() -> None:
|
||||
"""Test that omitting both user_id and email returns 400."""
|
||||
response = client.get("/admin/rate_limit")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_reset_user_usage_daily_only(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test resetting only daily usage (default behaviour)."""
|
||||
mock_reset = mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=3_000_000)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_cost_used_microdollars"] == 0
|
||||
# Weekly is untouched
|
||||
assert data["weekly_cost_used_microdollars"] == 3_000_000
|
||||
assert data["tier"] == "BASIC"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"reset_user_usage_daily_only",
|
||||
)
|
||||
|
||||
|
||||
def test_reset_user_usage_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test resetting both daily and weekly usage."""
|
||||
mock_reset = mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=0)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id, "reset_weekly": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_cost_used_microdollars"] == 0
|
||||
assert data["weekly_cost_used_microdollars"] == 0
|
||||
assert data["tier"] == "BASIC"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"reset_user_usage_daily_and_weekly",
|
||||
)
|
||||
|
||||
|
||||
def test_reset_user_usage_redis_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that Redis failure on reset returns 500."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Redis connection refused"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_get_rate_limit_email_lookup_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that failing to resolve a user email degrades gracefully."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection lost"),
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] is None
|
||||
|
||||
|
||||
def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that rate limit admin endpoints require admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": "test"})
|
||||
assert response.status_code == 403
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": "test"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tier management endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_user_tier(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting a user's rate-limit tier."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.PRO,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "PRO"
|
||||
|
||||
|
||||
def test_get_user_tier_user_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that getting tier for a non-existent user returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test setting a user's rate-limit tier (upgrade)."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.BASIC,
|
||||
)
|
||||
mock_set = mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "ENTERPRISE"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "ENTERPRISE"
|
||||
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE)
|
||||
|
||||
|
||||
def test_set_user_tier_downgrade(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test downgrading a user's tier from PRO to BASIC."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.PRO,
|
||||
)
|
||||
mock_set = mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "BASIC"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "BASIC"
|
||||
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.BASIC)
|
||||
|
||||
|
||||
def test_set_user_tier_invalid_tier(
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting an invalid tier returns 422."""
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "invalid"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_set_user_tier_invalid_tier_uppercase(
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422.
|
||||
|
||||
Regression: ensures Pydantic enum validation rejects values that are not
|
||||
members of SubscriptionTier, even when they look like valid enum names.
|
||||
"""
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "INVALID"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
body = response.json()
|
||||
assert "detail" in body
|
||||
|
||||
|
||||
def test_set_user_tier_email_lookup_failure_returns_404(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that email lookup failure returns 404 (user unverifiable)."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection failed"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier_user_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting tier for a non-existent user returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier_db_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that DB failure on set tier returns 500."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.BASIC,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection refused"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that tier admin endpoints require admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": "test"})
|
||||
assert response.status_code == 403
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": "test", "tier": "PRO"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ─── search_users endpoint ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_search_users_returns_matching_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Partial search should return all matching users from the User table."""
|
||||
mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[
|
||||
("user-1", "zamil.majdy@gmail.com"),
|
||||
("user-2", "zamil.majdy@agpt.co"),
|
||||
],
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"})
|
||||
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
assert len(results) == 2
|
||||
assert results[0]["user_email"] == "zamil.majdy@gmail.com"
|
||||
assert results[1]["user_email"] == "zamil.majdy@agpt.co"
|
||||
|
||||
|
||||
def test_search_users_empty_results(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Search with no matches returns empty list."""
|
||||
mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/rate_limit/search_users", params={"query": "nonexistent"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
def test_search_users_short_query_rejected(
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Query shorter than 3 characters should return 400."""
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "ab"})
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_search_users_negative_limit_clamped(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Negative limit should be clamped to 1, not passed through."""
|
||||
mock_search = mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/rate_limit/search_users", params={"query": "test", "limit": -1}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_search.assert_awaited_once_with("test", limit=1)
|
||||
|
||||
|
||||
def test_search_users_requires_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that the search_users endpoint requires admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "test"})
|
||||
assert response.status_code == 403
|
||||
@@ -7,8 +7,6 @@ import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
@@ -134,40 +132,3 @@ async def admin_download_agent_file(
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/submissions/{store_listing_version_id}/preview",
|
||||
summary="Admin Preview Submission Listing",
|
||||
)
|
||||
async def admin_preview_submission(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""
|
||||
Preview a marketplace submission as it would appear on the listing page.
|
||||
Bypasses the APPROVED-only StoreAgent view so admins can preview pending
|
||||
submissions before approving.
|
||||
"""
|
||||
return await store_db.get_store_agent_details_as_admin(store_listing_version_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/add-to-library",
|
||||
summary="Admin Add Pending Agent to Library",
|
||||
status_code=201,
|
||||
)
|
||||
async def admin_add_agent_to_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Add a pending marketplace agent to the admin's library for review.
|
||||
Uses admin-level access to bypass marketplace APPROVED-only checks.
|
||||
|
||||
The builder can load the graph because get_graph() checks library
|
||||
membership as a fallback: "you added it, you keep it."
|
||||
"""
|
||||
return await library_db.add_store_agent_to_library_as_admin(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -1,335 +0,0 @@
|
||||
"""Tests for admin store routes and the bypass logic they depend on.
|
||||
|
||||
Tests are organized by what they protect:
|
||||
- SECRT-2162: get_graph_as_admin bypasses ownership/marketplace checks
|
||||
- SECRT-2167 security: admin endpoints reject non-admin users
|
||||
- SECRT-2167 bypass: preview queries StoreListingVersion (not StoreAgent view),
|
||||
and add-to-library uses get_graph_as_admin (not get_graph)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.graph import get_graph_as_admin
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .store_admin_routes import router as store_admin_router
|
||||
|
||||
# Shared constants
|
||||
ADMIN_USER_ID = "admin-user-id"
|
||||
CREATOR_USER_ID = "other-creator-id"
|
||||
GRAPH_ID = "test-graph-id"
|
||||
GRAPH_VERSION = 3
|
||||
SLV_ID = "test-store-listing-version-id"
|
||||
|
||||
|
||||
def _make_mock_graph(user_id: str = CREATOR_USER_ID) -> MagicMock:
|
||||
graph = MagicMock()
|
||||
graph.userId = user_id
|
||||
graph.id = GRAPH_ID
|
||||
graph.version = GRAPH_VERSION
|
||||
graph.Nodes = []
|
||||
return graph
|
||||
|
||||
|
||||
# ---- SECRT-2162: get_graph_as_admin bypasses ownership checks ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_access_pending_agent_not_owned() -> None:
|
||||
"""get_graph_as_admin must return a graph even when the admin doesn't own
|
||||
it and it's not APPROVED in the marketplace."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=False,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_download_pending_agent_with_subagents() -> None:
|
||||
"""get_graph_as_admin with for_export=True must call get_sub_graphs
|
||||
and pass sub_graphs to GraphModel.from_db."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_sub_graph = MagicMock(name="SubGraph")
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.get_sub_graphs",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_sub_graph],
|
||||
) as mock_get_sub,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=True,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_get_sub.assert_awaited_once_with(mock_graph)
|
||||
mock_from_db.assert_called_once_with(
|
||||
graph=mock_graph,
|
||||
sub_graphs=[mock_sub_graph],
|
||||
for_export=True,
|
||||
)
|
||||
|
||||
|
||||
# ---- SECRT-2167 security: admin endpoints reject non-admin users ---- #
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(store_admin_router)
|
||||
|
||||
|
||||
@app.exception_handler(NotFoundError)
|
||||
async def _not_found_handler(
|
||||
request: fastapi.Request, exc: NotFoundError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all route tests in this module."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_preview_requires_admin(mock_jwt_user) -> None:
|
||||
"""Non-admin users must get 403 on the preview endpoint."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_add_to_library_requires_admin(mock_jwt_user) -> None:
|
||||
"""Non-admin users must get 403 on the add-to-library endpoint."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
response = client.post(f"/admin/submissions/{SLV_ID}/add-to-library")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_preview_nonexistent_submission(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Preview of a nonexistent submission returns 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.store_admin_routes.store_db"
|
||||
".get_store_agent_details_as_admin",
|
||||
side_effect=NotFoundError("not found"),
|
||||
)
|
||||
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- SECRT-2167 bypass: verify the right data sources are used ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_queries_store_listing_version_not_store_agent() -> None:
|
||||
"""get_store_agent_details_as_admin must query StoreListingVersion
|
||||
directly (not the APPROVED-only StoreAgent view). This is THE test that
|
||||
prevents the bypass from being accidentally reverted."""
|
||||
from backend.api.features.store.db import get_store_agent_details_as_admin
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.id = SLV_ID
|
||||
mock_slv.name = "Test Agent"
|
||||
mock_slv.subHeading = "Short desc"
|
||||
mock_slv.description = "Long desc"
|
||||
mock_slv.videoUrl = None
|
||||
mock_slv.agentOutputDemoUrl = None
|
||||
mock_slv.imageUrls = ["https://example.com/img.png"]
|
||||
mock_slv.instructions = None
|
||||
mock_slv.categories = ["productivity"]
|
||||
mock_slv.version = 1
|
||||
mock_slv.agentGraphId = GRAPH_ID
|
||||
mock_slv.agentGraphVersion = GRAPH_VERSION
|
||||
mock_slv.updatedAt = datetime(2026, 3, 24, tzinfo=timezone.utc)
|
||||
mock_slv.recommendedScheduleCron = "0 9 * * *"
|
||||
|
||||
mock_listing = MagicMock()
|
||||
mock_listing.id = "listing-id"
|
||||
mock_listing.slug = "test-agent"
|
||||
mock_listing.activeVersionId = SLV_ID
|
||||
mock_listing.hasApprovedVersion = False
|
||||
mock_listing.CreatorProfile = MagicMock(username="creator", avatarUrl="")
|
||||
mock_slv.StoreListing = mock_listing
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.store.db.prisma.models" ".StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch(
|
||||
"backend.api.features.store.db.prisma.models.StoreAgent.prisma",
|
||||
) as mock_store_agent_prisma,
|
||||
):
|
||||
mock_slv_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await get_store_agent_details_as_admin(SLV_ID)
|
||||
|
||||
# Verify it queried StoreListingVersion (not the APPROVED-only StoreAgent)
|
||||
mock_slv_prisma.return_value.find_unique.assert_awaited_once()
|
||||
await_args = mock_slv_prisma.return_value.find_unique.await_args
|
||||
assert await_args is not None
|
||||
assert await_args.kwargs["where"] == {"id": SLV_ID}
|
||||
|
||||
# Verify the APPROVED-only StoreAgent view was NOT touched
|
||||
mock_store_agent_prisma.assert_not_called()
|
||||
|
||||
# Verify the result has the right data
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.agent_image == ["https://example.com/img.png"]
|
||||
assert result.has_approved_version is False
|
||||
assert result.runs == 0
|
||||
assert result.rating == 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_graph_admin_uses_get_graph_as_admin() -> None:
|
||||
"""resolve_graph_for_library(admin=True) must call get_graph_as_admin,
|
||||
not get_graph. This is THE test that prevents the add-to-library bypass
|
||||
from being accidentally reverted."""
|
||||
from backend.api.features.library._add_to_library import resolve_graph_for_library
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models"
|
||||
".StoreListingVersion.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
".get_graph_as_admin",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_graph_model,
|
||||
) as mock_admin,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db.get_graph",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_regular,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await resolve_graph_for_library(SLV_ID, ADMIN_USER_ID, admin=True)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_admin.assert_awaited_once_with(
|
||||
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id=ADMIN_USER_ID
|
||||
)
|
||||
mock_regular.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_graph_regular_uses_get_graph() -> None:
|
||||
"""resolve_graph_for_library(admin=False) must call get_graph,
|
||||
not get_graph_as_admin. Ensures the non-admin path is preserved."""
|
||||
from backend.api.features.library._add_to_library import resolve_graph_for_library
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models"
|
||||
".StoreListingVersion.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
".get_graph_as_admin",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_admin,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db.get_graph",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_graph_model,
|
||||
) as mock_regular,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await resolve_graph_for_library(SLV_ID, "regular-user-id", admin=False)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_regular.assert_awaited_once_with(
|
||||
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id="regular-user-id"
|
||||
)
|
||||
mock_admin.assert_not_awaited()
|
||||
|
||||
|
||||
# ---- Library membership grants graph access (product decision) ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_member_can_view_pending_agent_in_builder() -> None:
|
||||
"""After adding a pending agent to their library, the user should be
|
||||
able to load the graph in the builder via get_graph()."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
mock_library_agent = MagicMock()
|
||||
mock_library_agent.AgentGraph = mock_graph
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_lib_prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_library_agent
|
||||
)
|
||||
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model, "Library membership should grant graph access"
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,13 +0,0 @@
|
||||
"""Override session-scoped fixtures so unit tests run without the server."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server():
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def graph_cleanup():
|
||||
yield
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,7 +14,7 @@ from fastapi import (
|
||||
Security,
|
||||
status,
|
||||
)
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, SecretStr, model_validator
|
||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||
|
||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||
@@ -29,32 +29,21 @@ from backend.data.integrations import (
|
||||
wait_for_webhook_event,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
is_sdk_default,
|
||||
UserIntegrations,
|
||||
)
|
||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
create_mcp_oauth_handler,
|
||||
)
|
||||
from backend.integrations.managed_credentials import (
|
||||
ensure_managed_credential,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
from backend.integrations.managed_providers.ayrshare import AyrshareManagedProvider
|
||||
from backend.integrations.managed_providers.ayrshare import (
|
||||
settings_available as ayrshare_settings_available,
|
||||
)
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -93,23 +82,14 @@ async def login(
|
||||
scopes: Annotated[
|
||||
str, Query(title="Comma-separated list of authorization scopes")
|
||||
] = "",
|
||||
credential_id: Annotated[
|
||||
str | None,
|
||||
Query(title="ID of existing credential to upgrade scopes for"),
|
||||
] = None,
|
||||
) -> LoginResponse:
|
||||
handler = _get_provider_oauth_handler(request, provider)
|
||||
|
||||
requested_scopes = scopes.split(",") if scopes else []
|
||||
|
||||
if credential_id:
|
||||
requested_scopes = await _prepare_scope_upgrade(
|
||||
user_id, provider, credential_id, requested_scopes
|
||||
)
|
||||
|
||||
# Generate and store a secure random state token along with the scopes
|
||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||
user_id, provider, requested_scopes, credential_id=credential_id
|
||||
user_id, provider, requested_scopes
|
||||
)
|
||||
login_url = handler.get_login_url(
|
||||
requested_scopes, state_token, code_challenge=code_challenge
|
||||
@@ -129,7 +109,6 @@ class CredentialsMetaResponse(BaseModel):
|
||||
default=None,
|
||||
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
||||
)
|
||||
is_managed: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -159,19 +138,6 @@ class CredentialsMetaResponse(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
|
||||
return CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
is_managed=cred.is_managed,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||
async def callback(
|
||||
provider: Annotated[
|
||||
@@ -231,54 +197,41 @@ async def callback(
|
||||
)
|
||||
|
||||
# TODO: Allow specifying `title` to set on `credentials`
|
||||
credentials = await _merge_or_create_credential(
|
||||
user_id, provider, credentials, valid_state.credential_id
|
||||
)
|
||||
await creds_manager.create(user_id, credentials)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully processed OAuth callback for user {user_id} "
|
||||
f"and provider {provider.value}"
|
||||
)
|
||||
|
||||
return to_meta_response(credentials)
|
||||
|
||||
|
||||
# Bound the first-time sweep so a slow upstream (e.g. Ayrshare) can't hang
|
||||
# the credential-list endpoint. On timeout we still kick off a fire-and-
|
||||
# forget sweep so provisioning eventually completes; the user just won't
|
||||
# see the managed cred until the next refresh.
|
||||
_MANAGED_PROVISION_TIMEOUT_S = 10.0
|
||||
|
||||
|
||||
async def _ensure_managed_credentials_bounded(user_id: str) -> None:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
ensure_managed_credentials(user_id, creds_manager.store),
|
||||
timeout=_MANAGED_PROVISION_TIMEOUT_S,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Managed credential sweep exceeded %.1fs for user=%s; "
|
||||
"continuing without it — provisioning will complete in background",
|
||||
_MANAGED_PROVISION_TIMEOUT_S,
|
||||
user_id,
|
||||
)
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(CredentialsMetaResponse.get_host(credentials)),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/credentials", summary="List Credentials")
|
||||
async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
# Block on provisioning so managed credentials appear on the first load
|
||||
# instead of after a refresh, but with a timeout so a slow upstream
|
||||
# can't hang the endpoint. `_provisioned_users` short-circuits on
|
||||
# repeat calls.
|
||||
await _ensure_managed_credentials_bounded(user_id)
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
|
||||
|
||||
@@ -289,11 +242,19 @@ async def list_credentials_by_provider(
|
||||
],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
await _ensure_managed_credentials_bounded(user_id)
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
|
||||
|
||||
@@ -306,130 +267,18 @@ async def get_credential(
|
||||
],
|
||||
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> CredentialsMetaResponse:
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
) -> Credentials:
|
||||
credential = await creds_manager.get(user_id, cred_id)
|
||||
if not credential:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if not provider_matches(credential.provider, provider):
|
||||
if credential.provider != provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
return to_meta_response(credential)
|
||||
|
||||
|
||||
class PickerTokenResponse(BaseModel):
|
||||
"""Short-lived OAuth access token shipped to the browser for rendering a
|
||||
provider-hosted picker UI (e.g. Google Drive Picker). Deliberately narrow:
|
||||
only the fields the client needs to initialize the picker widget. Issued
|
||||
from the user's own stored credential so ownership and scope gating are
|
||||
enforced by the credential lookup."""
|
||||
|
||||
access_token: str = Field(
|
||||
description="OAuth access token suitable for the picker SDK call."
|
||||
)
|
||||
access_token_expires_at: int | None = Field(
|
||||
default=None,
|
||||
description="Unix timestamp at which the access token expires, if known.",
|
||||
)
|
||||
|
||||
|
||||
# Allowlist of (provider, scopes) tuples that may mint picker tokens. Only
|
||||
# Drive-picker-capable scopes qualify so a caller can't use this endpoint to
|
||||
# extract a GitHub / other-provider OAuth token for unrelated purposes. If a
|
||||
# future provider integrates a hosted picker that needs a raw access token,
|
||||
# add its specific picker-relevant scopes here.
|
||||
_PICKER_TOKEN_ALLOWED_SCOPES: dict[ProviderName, frozenset[str]] = {
|
||||
ProviderName.GOOGLE: frozenset(
|
||||
[
|
||||
"https://www.googleapis.com/auth/drive.file",
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive",
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{provider}/credentials/{cred_id}/picker-token",
|
||||
summary="Issue a short-lived access token for a provider-hosted picker",
|
||||
operation_id="postV1GetPickerToken",
|
||||
)
|
||||
async def get_picker_token(
|
||||
provider: Annotated[
|
||||
ProviderName, Path(title="The provider that owns the credentials")
|
||||
],
|
||||
cred_id: Annotated[
|
||||
str, Path(title="The ID of the OAuth2 credentials to mint a token from")
|
||||
],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> PickerTokenResponse:
|
||||
"""Return the raw access token for an OAuth2 credential so the frontend
|
||||
can initialize a provider-hosted picker (e.g. Google Drive Picker).
|
||||
|
||||
`GET /{provider}/credentials/{cred_id}` deliberately strips secrets (see
|
||||
`CredentialsMetaResponse` + `TestGetCredentialReturnsMetaOnly` in
|
||||
`router_test.py`). That hardening broke the Drive picker, which needs the
|
||||
raw access token to call `google.picker.Builder.setOAuthToken(...)`. This
|
||||
endpoint carves a narrow, explicit hole: the caller must own the
|
||||
credential, it must be OAuth2, and the endpoint returns only the access
|
||||
token + its expiry — nothing else about the credential. SDK-default
|
||||
credentials are excluded for the same reason as `get_credential`.
|
||||
"""
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
|
||||
credential = await creds_manager.get(user_id, cred_id)
|
||||
if not credential:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if not provider_matches(credential.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if not isinstance(credential, OAuth2Credentials):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Picker tokens are only available for OAuth2 credentials",
|
||||
)
|
||||
if not credential.access_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Credential has no access token; reconnect the account",
|
||||
)
|
||||
|
||||
# Gate on provider+scope: only credentials that actually grant access to
|
||||
# a provider-hosted picker flow may mint a token through this endpoint.
|
||||
# Prevents using this path to extract bearer tokens for unrelated OAuth
|
||||
# integrations (e.g. GitHub) that happen to be stored under the same user.
|
||||
allowed_scopes = _PICKER_TOKEN_ALLOWED_SCOPES.get(provider)
|
||||
if not allowed_scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(f"Picker tokens are not available for provider '{provider.value}'"),
|
||||
)
|
||||
cred_scopes = set(credential.scopes or [])
|
||||
if cred_scopes.isdisjoint(allowed_scopes):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"Credential does not grant any scope eligible for the picker. "
|
||||
"Reconnect with the appropriate scope."
|
||||
),
|
||||
)
|
||||
|
||||
return PickerTokenResponse(
|
||||
access_token=credential.access_token.get_secret_value(),
|
||||
access_token_expires_at=credential.access_token_expires_at,
|
||||
)
|
||||
return credential
|
||||
|
||||
|
||||
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
|
||||
@@ -439,22 +288,16 @@ async def create_credentials(
|
||||
ProviderName, Path(title="The provider to create credentials for")
|
||||
],
|
||||
credentials: Credentials,
|
||||
) -> CredentialsMetaResponse:
|
||||
if is_sdk_default(credentials.id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Cannot create credentials with a reserved ID",
|
||||
)
|
||||
) -> Credentials:
|
||||
credentials.provider = provider
|
||||
try:
|
||||
await creds_manager.create(user_id, credentials)
|
||||
except Exception:
|
||||
logger.exception("Failed to store credentials")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to store credentials",
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
)
|
||||
return to_meta_response(credentials)
|
||||
return credentials
|
||||
|
||||
|
||||
class CredentialsDeletionResponse(BaseModel):
|
||||
@@ -489,29 +332,15 @@ async def delete_credentials(
|
||||
bool, Query(title="Whether to proceed if any linked webhooks are still in use")
|
||||
] = False,
|
||||
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if not provider_matches(creds.provider, provider):
|
||||
if creds.provider != provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials not found",
|
||||
)
|
||||
if creds.is_managed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="AutoGPT-managed credentials cannot be deleted",
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -725,186 +554,6 @@ async def _execute_webhook_preset_trigger(
|
||||
# Continue processing - webhook should be resilient to individual failures
|
||||
|
||||
|
||||
# -------------------- INCREMENTAL AUTH HELPERS -------------------- #
|
||||
|
||||
|
||||
async def _prepare_scope_upgrade(
|
||||
user_id: str,
|
||||
provider: ProviderName,
|
||||
credential_id: str,
|
||||
requested_scopes: list[str],
|
||||
) -> list[str]:
|
||||
"""Validate an existing credential for scope upgrade and compute scopes.
|
||||
|
||||
For providers without native incremental auth (e.g. GitHub), returns the
|
||||
union of existing + requested scopes. For providers that handle merging
|
||||
server-side (e.g. Google with ``include_granted_scopes``), returns the
|
||||
requested scopes unchanged.
|
||||
|
||||
Raises HTTPException on validation failure.
|
||||
"""
|
||||
# Platform-owned system credentials must never be upgraded — scope
|
||||
# changes here would leak across every user that shares them.
|
||||
if is_system_credential(credential_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="System credentials cannot be upgraded",
|
||||
)
|
||||
|
||||
existing = await creds_manager.store.get_creds_by_id(user_id, credential_id)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credential to upgrade not found",
|
||||
)
|
||||
if not isinstance(existing, OAuth2Credentials):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Only OAuth2 credentials can be upgraded",
|
||||
)
|
||||
if not provider_matches(existing.provider, provider.value):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Credential provider does not match the requested provider",
|
||||
)
|
||||
if existing.is_managed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Managed credentials cannot be upgraded",
|
||||
)
|
||||
|
||||
# Google handles scope merging via include_granted_scopes; others need
|
||||
# the union of existing + new scopes in the login URL.
|
||||
if provider != ProviderName.GOOGLE:
|
||||
requested_scopes = list(set(requested_scopes) | set(existing.scopes))
|
||||
|
||||
return requested_scopes
|
||||
|
||||
|
||||
async def _merge_or_create_credential(
|
||||
user_id: str,
|
||||
provider: ProviderName,
|
||||
credentials: OAuth2Credentials,
|
||||
credential_id: str | None,
|
||||
) -> OAuth2Credentials:
|
||||
"""Either upgrade an existing credential or create a new one.
|
||||
|
||||
When *credential_id* is set (explicit upgrade), merges scopes and updates
|
||||
the existing credential. Otherwise, checks for an implicit merge (same
|
||||
provider + username) before falling back to creating a new credential.
|
||||
"""
|
||||
if credential_id:
|
||||
return await _upgrade_existing_credential(user_id, credential_id, credentials)
|
||||
|
||||
# Implicit merge: check for existing credential with same provider+username.
|
||||
# Skip managed/system credentials and require a non-None username on both
|
||||
# sides so we never accidentally merge unrelated credentials.
|
||||
if credentials.username is None:
|
||||
await creds_manager.create(user_id, credentials)
|
||||
return credentials
|
||||
|
||||
existing_creds = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
matching = next(
|
||||
(
|
||||
c
|
||||
for c in existing_creds
|
||||
if isinstance(c, OAuth2Credentials)
|
||||
and not c.is_managed
|
||||
and not is_system_credential(c.id)
|
||||
and c.username is not None
|
||||
and c.username == credentials.username
|
||||
),
|
||||
None,
|
||||
)
|
||||
if matching:
|
||||
# Only merge into the existing credential when the new token
|
||||
# already covers every scope we're about to advertise on it.
|
||||
# Without this guard we'd overwrite ``matching.access_token`` with
|
||||
# a narrower token while storing a wider ``scopes`` list — the
|
||||
# record would claim authorizations the token does not grant, and
|
||||
# blocks using the lost scopes would fail with opaque 401/403s
|
||||
# until the user hits re-auth. On a narrowing login, keep the
|
||||
# two credentials separate instead.
|
||||
if set(credentials.scopes).issuperset(set(matching.scopes)):
|
||||
return await _upgrade_existing_credential(user_id, matching.id, credentials)
|
||||
|
||||
await creds_manager.create(user_id, credentials)
|
||||
return credentials
|
||||
|
||||
|
||||
async def _upgrade_existing_credential(
|
||||
user_id: str,
|
||||
existing_cred_id: str,
|
||||
new_credentials: OAuth2Credentials,
|
||||
) -> OAuth2Credentials:
|
||||
"""Merge scopes from *new_credentials* into an existing credential."""
|
||||
# Defense-in-depth: re-check system and provider invariants right before
|
||||
# the write. The login-time check in `_prepare_scope_upgrade` can go stale
|
||||
# by the time the callback runs, and the implicit-merge path bypasses
|
||||
# login-time validation entirely, so every write-path must enforce these
|
||||
# on its own.
|
||||
if is_system_credential(existing_cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="System credentials cannot be upgraded",
|
||||
)
|
||||
existing = await creds_manager.store.get_creds_by_id(user_id, existing_cred_id)
|
||||
if not existing or not isinstance(existing, OAuth2Credentials):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Credential to upgrade not found",
|
||||
)
|
||||
if existing.is_managed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Managed credentials cannot be upgraded",
|
||||
)
|
||||
if not provider_matches(existing.provider, new_credentials.provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Credential provider does not match the requested provider",
|
||||
)
|
||||
|
||||
if (
|
||||
existing.username
|
||||
and new_credentials.username
|
||||
and existing.username != new_credentials.username
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username mismatch: authenticated as a different user",
|
||||
)
|
||||
|
||||
# Operate on a copy so the caller's ``new_credentials`` object is not
|
||||
# mutated out from under them. Every caller today immediately discards
|
||||
# or replaces its reference, but the implicit-merge path in
|
||||
# ``_merge_or_create_credential`` reads ``credentials.scopes`` before
|
||||
# calling into us — a future reader after the call would otherwise
|
||||
# silently see the overwritten values.
|
||||
merged = new_credentials.model_copy(deep=True)
|
||||
merged.id = existing.id
|
||||
merged.title = existing.title
|
||||
merged.scopes = list(set(existing.scopes) | set(new_credentials.scopes))
|
||||
merged.metadata = {
|
||||
**(existing.metadata or {}),
|
||||
**(new_credentials.metadata or {}),
|
||||
}
|
||||
# Preserve the existing refresh_token and username if the incremental
|
||||
# response doesn't carry them. Providers like Google only return a
|
||||
# refresh_token on first authorization — dropping it here would orphan
|
||||
# the credential on the next access-token expiry, forcing the user to
|
||||
# re-auth from scratch. Username is similarly sticky: if we've already
|
||||
# resolved it for this credential, keep it rather than silently
|
||||
# blanking it on an incremental upgrade.
|
||||
if not merged.refresh_token and existing.refresh_token:
|
||||
merged.refresh_token = existing.refresh_token
|
||||
merged.refresh_token_expires_at = existing.refresh_token_expires_at
|
||||
if not merged.username and existing.username:
|
||||
merged.username = existing.username
|
||||
await creds_manager.update(user_id, merged)
|
||||
return merged
|
||||
|
||||
|
||||
# --------------------------- UTILITIES ---------------------------- #
|
||||
|
||||
|
||||
@@ -1115,21 +764,12 @@ def _get_provider_oauth_handler(
|
||||
async def get_ayrshare_sso_url(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> AyrshareSSOResponse:
|
||||
"""Generate a JWT SSO URL so the user can link their social accounts.
|
||||
|
||||
The per-user Ayrshare profile key is provisioned and persisted as a
|
||||
standard ``is_managed=True`` credential by
|
||||
:class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`.
|
||||
This endpoint only signs a short-lived JWT pointing at the Ayrshare-
|
||||
hosted social-linking page; all profile lifecycle logic lives with the
|
||||
managed provider.
|
||||
"""
|
||||
if not ayrshare_settings_available():
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Ayrshare integration is not configured",
|
||||
)
|
||||
Generate an SSO URL for Ayrshare social media integration.
|
||||
|
||||
Returns:
|
||||
dict: Contains the SSO URL for Ayrshare integration
|
||||
"""
|
||||
try:
|
||||
client = AyrshareClient()
|
||||
except MissingConfigError:
|
||||
@@ -1138,63 +778,66 @@ async def get_ayrshare_sso_url(
|
||||
detail="Ayrshare integration is not configured",
|
||||
)
|
||||
|
||||
# On-demand provisioning: AyrshareManagedProvider opts out of the
|
||||
# credentials sweep (profile quota is per-user subscription-bound). This
|
||||
# endpoint is the only trigger that provisions a profile — one Ayrshare
|
||||
# profile per user who actually opens the connect flow, not one per
|
||||
# every authenticated user.
|
||||
provisioned = await ensure_managed_credential(
|
||||
user_id, creds_manager.store, AyrshareManagedProvider()
|
||||
)
|
||||
if not provisioned:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_502_BAD_GATEWAY,
|
||||
detail="Failed to provision Ayrshare profile",
|
||||
)
|
||||
# Ayrshare profile key is stored in the credentials store
|
||||
# It is generated when creating a new profile, if there is no profile key,
|
||||
# we create a new profile and store the profile key in the credentials store
|
||||
|
||||
ayrshare_creds = [
|
||||
c
|
||||
for c in await creds_manager.store.get_creds_by_provider(user_id, "ayrshare")
|
||||
if c.is_managed and isinstance(c, APIKeyCredentials)
|
||||
]
|
||||
if not ayrshare_creds:
|
||||
logger.error(
|
||||
"Ayrshare credential provisioning did not produce a credential "
|
||||
"for user %s",
|
||||
user_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_502_BAD_GATEWAY,
|
||||
detail="Failed to provision Ayrshare profile",
|
||||
)
|
||||
profile_key_str = ayrshare_creds[0].api_key.get_secret_value()
|
||||
user_integrations: UserIntegrations = await get_user_integrations(user_id)
|
||||
profile_key = user_integrations.managed_credentials.ayrshare_profile_key
|
||||
|
||||
if not profile_key:
|
||||
logger.debug(f"Creating new Ayrshare profile for user {user_id}")
|
||||
try:
|
||||
profile = await client.create_profile(
|
||||
title=f"User {user_id}", messaging_active=True
|
||||
)
|
||||
profile_key = profile.profileKey
|
||||
await creds_manager.store.set_ayrshare_profile_key(user_id, profile_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating Ayrshare profile for user {user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_502_BAD_GATEWAY,
|
||||
detail="Failed to create Ayrshare profile",
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Using existing Ayrshare profile for user {user_id}")
|
||||
|
||||
profile_key_str = (
|
||||
profile_key.get_secret_value()
|
||||
if isinstance(profile_key, SecretStr)
|
||||
else str(profile_key)
|
||||
)
|
||||
|
||||
private_key = settings.secrets.ayrshare_jwt_key
|
||||
# Ayrshare JWT max lifetime is 2880 minutes (48 h).
|
||||
# Ayrshare JWT expiry is 2880 minutes (48 hours)
|
||||
max_expiry_minutes = 2880
|
||||
try:
|
||||
logger.debug(f"Generating Ayrshare JWT for user {user_id}")
|
||||
jwt_response = await client.generate_jwt(
|
||||
private_key=private_key,
|
||||
profile_key=profile_key_str,
|
||||
# `allowed_social` is the set of networks the Ayrshare-hosted
|
||||
# social-linking page will *offer* the user to connect. Blocks
|
||||
# exist for more platforms than are listed here; the list is
|
||||
# deliberately narrower so the rollout can verify each network
|
||||
# end-to-end before widening the user-visible surface. Keep
|
||||
# in sync with tested platforms — extend as each is verified
|
||||
# against the block + Ayrshare's network-specific quirks.
|
||||
allowed_social=[
|
||||
# NOTE: We are enabling platforms one at a time
|
||||
# to speed up the development process
|
||||
# SocialPlatform.FACEBOOK,
|
||||
SocialPlatform.TWITTER,
|
||||
SocialPlatform.LINKEDIN,
|
||||
SocialPlatform.INSTAGRAM,
|
||||
SocialPlatform.YOUTUBE,
|
||||
# SocialPlatform.REDDIT,
|
||||
# SocialPlatform.TELEGRAM,
|
||||
# SocialPlatform.GOOGLE_MY_BUSINESS,
|
||||
# SocialPlatform.PINTEREST,
|
||||
SocialPlatform.TIKTOK,
|
||||
# SocialPlatform.BLUESKY,
|
||||
# SocialPlatform.SNAPCHAT,
|
||||
# SocialPlatform.THREADS,
|
||||
],
|
||||
expires_in=max_expiry_minutes,
|
||||
verify=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Error generating Ayrshare JWT for user %s: %s", user_id, exc)
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating Ayrshare JWT for user {user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_502_BAD_GATEWAY, detail="Failed to generate JWT"
|
||||
)
|
||||
|
||||
@@ -1,748 +0,0 @@
|
||||
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.integrations.router import router
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "test-user-id"
|
||||
|
||||
|
||||
def _make_api_key_cred(cred_id: str = "cred-123", provider: str = "openai"):
|
||||
return APIKeyCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My API Key",
|
||||
api_key=SecretStr("sk-secret-key-value"),
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_cred(cred_id: str = "cred-456", provider: str = "github"):
|
||||
return OAuth2Credentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My OAuth",
|
||||
access_token=SecretStr("ghp_secret_token"),
|
||||
refresh_token=SecretStr("ghp_refresh_secret"),
|
||||
scopes=["repo", "user"],
|
||||
username="testuser",
|
||||
)
|
||||
|
||||
|
||||
def _make_user_password_cred(cred_id: str = "cred-789", provider: str = "openai"):
|
||||
return UserPasswordCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My Login",
|
||||
username=SecretStr("admin"),
|
||||
password=SecretStr("s3cret-pass"),
|
||||
)
|
||||
|
||||
|
||||
def _make_host_scoped_cred(cred_id: str = "cred-host", provider: str = "openai"):
|
||||
return HostScopedCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="Host Cred",
|
||||
host="https://api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer top-secret")},
|
||||
)
|
||||
|
||||
|
||||
def _make_sdk_default_cred(provider: str = "openai"):
|
||||
return APIKeyCredentials(
|
||||
id=f"{provider}-default",
|
||||
provider=provider,
|
||||
title=f"{provider} (default)",
|
||||
api_key=SecretStr("sk-platform-secret-key"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestGetCredentialReturnsMetaOnly:
|
||||
"""GET /{provider}/credentials/{cred_id} must not return secrets."""
|
||||
|
||||
def test_api_key_credential_no_secret(self):
|
||||
cred = _make_api_key_cred()
|
||||
with (
|
||||
patch.object(router, "dependencies", []),
|
||||
patch("backend.api.features.integrations.router.creds_manager") as mock_mgr,
|
||||
):
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-123")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-123"
|
||||
assert data["provider"] == "openai"
|
||||
assert data["type"] == "api_key"
|
||||
assert "api_key" not in data
|
||||
assert "sk-secret-key-value" not in str(data)
|
||||
|
||||
def test_oauth2_credential_no_secret(self):
|
||||
cred = _make_oauth2_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-456")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-456"
|
||||
assert data["scopes"] == ["repo", "user"]
|
||||
assert data["username"] == "testuser"
|
||||
assert "access_token" not in data
|
||||
assert "refresh_token" not in data
|
||||
assert "ghp_" not in str(data)
|
||||
|
||||
def test_user_password_credential_no_secret(self):
|
||||
cred = _make_user_password_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-789")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-789"
|
||||
assert "password" not in data
|
||||
assert "username" not in data or data["username"] is None
|
||||
assert "s3cret-pass" not in str(data)
|
||||
assert "admin" not in str(data)
|
||||
|
||||
def test_host_scoped_credential_no_secret(self):
|
||||
cred = _make_host_scoped_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-host")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-host"
|
||||
assert data["host"] == "https://api.example.com"
|
||||
assert "headers" not in data
|
||||
assert "top-secret" not in str(data)
|
||||
|
||||
def test_get_credential_wrong_provider_returns_404(self):
|
||||
"""Provider mismatch should return generic 404, not leak credential existence."""
|
||||
cred = _make_api_key_cred(provider="openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-123")
|
||||
|
||||
assert resp.status_code == 404
|
||||
assert resp.json()["detail"] == "Credentials not found"
|
||||
|
||||
def test_list_credentials_no_secrets(self):
|
||||
"""List endpoint must not leak secrets in any credential."""
|
||||
creds = [_make_api_key_cred(), _make_oauth2_cred()]
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=creds)
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
raw = str(resp.json())
|
||||
assert "sk-secret-key-value" not in raw
|
||||
assert "ghp_secret_token" not in raw
|
||||
assert "ghp_refresh_secret" not in raw
|
||||
|
||||
|
||||
class TestSdkDefaultCredentialsNotAccessible:
|
||||
"""SDK default credentials (ID ending in '-default') must be hidden."""
|
||||
|
||||
def test_get_sdk_default_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock()
|
||||
resp = client.get("/openai/credentials/openai-default")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.get.assert_not_called()
|
||||
|
||||
def test_list_credentials_excludes_sdk_defaults(self):
|
||||
user_cred = _make_api_key_cred()
|
||||
sdk_cred = _make_sdk_default_cred("openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=[user_cred, sdk_cred])
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
ids = [c["id"] for c in data]
|
||||
assert "cred-123" in ids
|
||||
assert "openai-default" not in ids
|
||||
|
||||
def test_list_by_provider_excludes_sdk_defaults(self):
|
||||
user_cred = _make_api_key_cred()
|
||||
sdk_cred = _make_sdk_default_cred("openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[user_cred, sdk_cred]
|
||||
)
|
||||
resp = client.get("/openai/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
ids = [c["id"] for c in data]
|
||||
assert "cred-123" in ids
|
||||
assert "openai-default" not in ids
|
||||
|
||||
def test_delete_sdk_default_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_id = AsyncMock()
|
||||
resp = client.request("DELETE", "/openai/credentials/openai-default")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.store.get_creds_by_id.assert_not_called()
|
||||
|
||||
|
||||
class TestCreateCredentialNoSecretInResponse:
|
||||
"""POST /{provider}/credentials must not return secrets."""
|
||||
|
||||
def test_create_api_key_no_secret_in_response(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.create = AsyncMock()
|
||||
resp = client.post(
|
||||
"/openai/credentials",
|
||||
json={
|
||||
"id": "new-cred",
|
||||
"provider": "openai",
|
||||
"type": "api_key",
|
||||
"title": "New Key",
|
||||
"api_key": "sk-newsecret",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["id"] == "new-cred"
|
||||
assert "api_key" not in data
|
||||
assert "sk-newsecret" not in str(data)
|
||||
|
||||
def test_create_with_sdk_default_id_rejected(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.create = AsyncMock()
|
||||
resp = client.post(
|
||||
"/openai/credentials",
|
||||
json={
|
||||
"id": "openai-default",
|
||||
"provider": "openai",
|
||||
"type": "api_key",
|
||||
"title": "Sneaky",
|
||||
"api_key": "sk-evil",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 403
|
||||
mock_mgr.create.assert_not_called()
|
||||
|
||||
|
||||
class TestManagedCredentials:
|
||||
"""AutoGPT-managed credentials cannot be deleted by users."""
|
||||
|
||||
def test_delete_is_managed_returns_403(self):
|
||||
cred = APIKeyCredentials(
|
||||
id="managed-cred-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-managed-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
|
||||
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
|
||||
|
||||
assert resp.status_code == 403
|
||||
assert "AutoGPT-managed" in resp.json()["detail"]
|
||||
|
||||
def test_list_credentials_includes_is_managed_field(self):
|
||||
managed = APIKeyCredentials(
|
||||
id="managed-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed)",
|
||||
api_key=SecretStr("sk-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
regular = APIKeyCredentials(
|
||||
id="regular-1",
|
||||
provider="openai",
|
||||
title="My Key",
|
||||
api_key=SecretStr("sk-key"),
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
managed_cred = next(c for c in data if c["id"] == "managed-1")
|
||||
regular_cred = next(c for c in data if c["id"] == "regular-1")
|
||||
assert managed_cred["is_managed"] is True
|
||||
assert regular_cred["is_managed"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Managed credential provisioning infrastructure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_managed_cred(
|
||||
provider: str = "agent_mail", pod_id: str = "pod-abc"
|
||||
) -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="managed-auto",
|
||||
provider=provider,
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-pod-key"),
|
||||
is_managed=True,
|
||||
metadata={"pod_id": pod_id},
|
||||
)
|
||||
|
||||
|
||||
def _make_store_mock(**kwargs) -> MagicMock:
|
||||
"""Create a store mock with a working async ``locks()`` context manager."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_locked(key):
|
||||
yield
|
||||
|
||||
locks_obj = MagicMock()
|
||||
locks_obj.locked = _noop_locked
|
||||
|
||||
store = MagicMock(**kwargs)
|
||||
store.locks = AsyncMock(return_value=locks_obj)
|
||||
return store
|
||||
|
||||
|
||||
class TestEnsureManagedCredentials:
|
||||
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provisions_when_missing(self):
|
||||
"""Provider.provision() is called when no managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(return_value=cred)
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
store.add_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_awaited_once_with("user-1", store)
|
||||
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_already_exists(self):
|
||||
"""Provider.provision() is NOT called when managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=True)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_unavailable(self):
|
||||
"""Provider.provision() is NOT called when provider is not available."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=False)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
store.has_managed_credential.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provision_failure_does_not_propagate(self):
|
||||
"""A failed provision is logged but does not raise."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
# No exception raised — provisioning failure is swallowed.
|
||||
|
||||
|
||||
class TestCleanupManagedCredentials:
|
||||
"""Unit tests for cleanup_managed_credentials."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_deprovision_for_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_non_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
regular = _make_api_key_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "openai"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[regular])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["openai"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprovision_failure_does_not_propagate(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
# No exception raised — cleanup failure is swallowed.
|
||||
|
||||
|
||||
class TestGetPickerToken:
|
||||
"""POST /{provider}/credentials/{cred_id}/picker-token must:
|
||||
1. Return the access token for OAuth2 creds the caller owns.
|
||||
2. 404 for non-owned, non-existent, or wrong-provider creds.
|
||||
3. 400 for non-OAuth2 creds (API key, host-scoped, user/password).
|
||||
4. 404 for SDK default creds (same hardening as get_credential).
|
||||
5. Preserve the `TestGetCredentialReturnsMetaOnly` contract — the
|
||||
existing meta-only endpoint must still strip secrets even after
|
||||
this picker-token endpoint exists."""
|
||||
|
||||
def test_oauth2_owner_gets_access_token(self):
|
||||
# Use a Google cred with a drive.file scope — only picker-eligible
|
||||
# (provider, scope) pairs can mint a token. GitHub-style creds are
|
||||
# explicitly rejected; see `test_non_picker_provider_rejected_as_400`.
|
||||
cred = _make_oauth2_cred(
|
||||
cred_id="cred-gdrive",
|
||||
provider="google",
|
||||
)
|
||||
cred.scopes = ["https://www.googleapis.com/auth/drive.file"]
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/google/credentials/cred-gdrive/picker-token")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# The whole point of this endpoint: the access token IS returned here.
|
||||
assert data["access_token"] == "ghp_secret_token"
|
||||
# Only the two declared fields come back — nothing else leaks.
|
||||
assert set(data.keys()) <= {"access_token", "access_token_expires_at"}
|
||||
|
||||
def test_non_picker_provider_rejected_as_400(self):
|
||||
"""Provider allowlist: even with a valid OAuth2 credential, a
|
||||
non-picker provider (GitHub, etc.) cannot mint a picker token.
|
||||
Stops this endpoint from being used as a generic bearer-token
|
||||
extraction path for any stored OAuth cred under the same user."""
|
||||
cred = _make_oauth2_cred(provider="github")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/github/credentials/cred-456/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert "not available for provider" in resp.json()["detail"]
|
||||
assert "ghp_secret_token" not in str(resp.json())
|
||||
|
||||
def test_google_oauth_without_drive_scope_rejected(self):
|
||||
"""Scope allowlist: a Google OAuth2 cred that only carries non-picker
|
||||
scopes (e.g. gmail.readonly, calendar) cannot mint a picker token.
|
||||
Forces the frontend to reconnect with a Drive scope before the
|
||||
picker is available."""
|
||||
cred = _make_oauth2_cred(provider="google")
|
||||
cred.scopes = [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/calendar",
|
||||
]
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/google/credentials/cred-456/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert "picker" in resp.json()["detail"].lower()
|
||||
|
||||
def test_api_key_credential_rejected_as_400(self):
|
||||
cred = _make_api_key_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/openai/credentials/cred-123/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
# API keys must not silently fall through to a 200 response of some
|
||||
# other shape — the client should see a clear shape rejection.
|
||||
body = str(resp.json())
|
||||
assert "sk-secret-key-value" not in body
|
||||
|
||||
def test_user_password_credential_rejected_as_400(self):
|
||||
cred = _make_user_password_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/openai/credentials/cred-789/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
body = str(resp.json())
|
||||
assert "s3cret-pass" not in body
|
||||
assert "admin" not in body
|
||||
|
||||
def test_host_scoped_credential_rejected_as_400(self):
|
||||
cred = _make_host_scoped_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/openai/credentials/cred-host/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert "top-secret" not in str(resp.json())
|
||||
|
||||
def test_missing_credential_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=None)
|
||||
resp = client.post("/github/credentials/nonexistent/picker-token")
|
||||
|
||||
assert resp.status_code == 404
|
||||
assert resp.json()["detail"] == "Credentials not found"
|
||||
|
||||
def test_wrong_provider_returns_404(self):
|
||||
"""Symmetric with get_credential: provider mismatch is a generic
|
||||
404, not a 400, so we don't leak existence of a credential the
|
||||
caller doesn't own on that provider."""
|
||||
cred = _make_oauth2_cred(provider="github")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/google/credentials/cred-456/picker-token")
|
||||
|
||||
assert resp.status_code == 404
|
||||
assert resp.json()["detail"] == "Credentials not found"
|
||||
|
||||
def test_sdk_default_returns_404(self):
|
||||
"""SDK defaults are invisible to the user-facing API — picker-token
|
||||
must not mint a token for them either."""
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock()
|
||||
resp = client.post("/openai/credentials/openai-default/picker-token")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.get.assert_not_called()
|
||||
|
||||
def test_oauth2_without_access_token_returns_400(self):
|
||||
"""A stored OAuth2 cred whose access_token is missing can't satisfy
|
||||
a picker init. Surface a clear reconnect instruction rather than
|
||||
returning an empty string."""
|
||||
cred = _make_oauth2_cred()
|
||||
# Simulate a cred that lost its access token
|
||||
object.__setattr__(cred, "access_token", None)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/github/credentials/cred-456/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert "reconnect" in resp.json()["detail"].lower()
|
||||
|
||||
def test_meta_only_endpoint_still_strips_access_token(self):
|
||||
"""Regression guard for the coexistence contract: the new
|
||||
picker-token endpoint must NOT accidentally leak the token through
|
||||
the meta-only GET endpoint. TestGetCredentialReturnsMetaOnly
|
||||
covers this more broadly; this is a fast sanity check co-located
|
||||
with the new endpoint's tests."""
|
||||
cred = _make_oauth2_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-456")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "access_token" not in body
|
||||
assert "refresh_token" not in body
|
||||
assert "ghp_secret_token" not in str(body)
|
||||
@@ -1,122 +0,0 @@
|
||||
"""Shared logic for adding store agents to a user's library.
|
||||
|
||||
Both `add_store_agent_to_library` and `add_store_agent_to_library_as_admin`
|
||||
delegate to these helpers so the duplication-prone create/restore/dedup
|
||||
logic lives in exactly one place.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.data.graph as graph_db
|
||||
from backend.api.features.library.db import _fetch_schedule_info
|
||||
from backend.data.graph import GraphModel, GraphSettings
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def resolve_graph_for_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
admin: bool,
|
||||
) -> GraphModel:
|
||||
"""Look up a StoreListingVersion and resolve its graph.
|
||||
|
||||
When ``admin=True``, uses ``get_graph_as_admin`` to bypass the marketplace
|
||||
APPROVED-only check. Otherwise uses the regular ``get_graph``.
|
||||
"""
|
||||
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
)
|
||||
if not slv or not slv.AgentGraph:
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
ag = slv.AgentGraph
|
||||
if admin:
|
||||
graph_model = await graph_db.get_graph_as_admin(
|
||||
graph_id=ag.id, version=ag.version, user_id=user_id
|
||||
)
|
||||
else:
|
||||
graph_model = await graph_db.get_graph(
|
||||
graph_id=ag.id, version=ag.version, user_id=user_id
|
||||
)
|
||||
|
||||
if not graph_model:
|
||||
raise NotFoundError(f"Graph #{ag.id} v{ag.version} not found or accessible")
|
||||
return graph_model
|
||||
|
||||
|
||||
async def add_graph_to_library(
|
||||
store_listing_version_id: str,
|
||||
graph_model: GraphModel,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Check existing / restore soft-deleted / create new LibraryAgent.
|
||||
|
||||
Uses a create-then-catch-UniqueViolationError-then-update pattern on
|
||||
the (userId, agentGraphId, agentGraphVersion) composite unique constraint.
|
||||
This is more robust than ``upsert`` because Prisma's upsert atomicity
|
||||
guarantees are not well-documented for all versions.
|
||||
"""
|
||||
settings_json = SafeJson(GraphSettings.from_graph(graph_model).model_dump())
|
||||
_include = library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
)
|
||||
|
||||
try:
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_model.id,
|
||||
"version": graph_model.version,
|
||||
}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": settings_json,
|
||||
},
|
||||
include=_include,
|
||||
)
|
||||
except prisma.errors.UniqueViolationError:
|
||||
# Already exists — update to restore if previously soft-deleted/archived
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().update(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_model.id,
|
||||
"agentGraphVersion": graph_model.version,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"settings": settings_json,
|
||||
},
|
||||
include=_include,
|
||||
)
|
||||
if added_agent is None:
|
||||
raise NotFoundError(
|
||||
f"LibraryAgent for graph #{graph_model.id} "
|
||||
f"v{graph_model.version} not found after UniqueViolationError"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Added graph #{graph_model.id} v{graph_model.version} "
|
||||
f"for store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=graph_model.id)
|
||||
return library_model.LibraryAgent.from_db(added_agent, schedule_info=schedule_info)
|
||||
@@ -1,88 +0,0 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma.errors
|
||||
import pytest
|
||||
|
||||
from ._add_to_library import add_graph_to_library
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_create_new_agent() -> None:
|
||||
"""When no matching LibraryAgent exists, create inserts a new one."""
|
||||
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
|
||||
created_agent = MagicMock(name="CreatedLibraryAgent")
|
||||
converted_agent = MagicMock(name="ConvertedLibraryAgent")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library._fetch_schedule_info",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(created_agent, schedule_info={})
|
||||
# Verify create was called with correct data
|
||||
create_call = mock_prisma.return_value.create.call_args
|
||||
create_data = create_call.kwargs["data"]
|
||||
assert create_data["User"] == {"connect": {"id": "user-id"}}
|
||||
assert create_data["AgentGraph"] == {
|
||||
"connect": {"graphVersionId": {"id": "graph-id", "version": 2}}
|
||||
}
|
||||
assert create_data["isCreatedByUser"] is False
|
||||
assert create_data["useGraphIsActiveVersion"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
|
||||
"""UniqueViolationError on create falls back to update."""
|
||||
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
|
||||
updated_agent = MagicMock(name="UpdatedLibraryAgent")
|
||||
converted_agent = MagicMock(name="ConvertedLibraryAgent")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library._fetch_schedule_info",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(
|
||||
side_effect=prisma.errors.UniqueViolationError(
|
||||
MagicMock(), message="unique constraint"
|
||||
)
|
||||
)
|
||||
mock_prisma.return_value.update = AsyncMock(return_value=updated_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(updated_agent, schedule_info={})
|
||||
# Verify update was called with correct where and data
|
||||
update_call = mock_prisma.return_value.update.call_args
|
||||
assert update_call.kwargs["where"] == {
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "user-id",
|
||||
"agentGraphId": "graph-id",
|
||||
"agentGraphVersion": 2,
|
||||
}
|
||||
}
|
||||
update_data = update_call.kwargs["data"]
|
||||
assert update_data["isDeleted"] is False
|
||||
assert update_data["isArchived"] is False
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
import fastapi
|
||||
@@ -44,65 +43,6 @@ config = Config()
|
||||
integration_creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]:
|
||||
"""Fetch execution counts per graph in a single batched query."""
|
||||
if not graph_ids:
|
||||
return {}
|
||||
rows = await prisma.models.AgentGraphExecution.prisma().group_by(
|
||||
by=["agentGraphId"],
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": {"in": graph_ids},
|
||||
"isDeleted": False,
|
||||
},
|
||||
count=True,
|
||||
)
|
||||
return {
|
||||
row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
|
||||
for row in rows
|
||||
}
|
||||
|
||||
|
||||
async def _fetch_schedule_info(
|
||||
user_id: str, graph_id: Optional[str] = None
|
||||
) -> dict[str, str]:
|
||||
"""Fetch a map of graph_id → earliest next_run_time ISO string.
|
||||
|
||||
When `graph_id` is provided, the scheduler query is narrowed to that graph,
|
||||
which is cheaper for single-agent lookups (detail page, post-update, etc.).
|
||||
"""
|
||||
try:
|
||||
scheduler_client = get_scheduler_client()
|
||||
schedules = await scheduler_client.get_execution_schedules(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
earliest: dict[str, tuple[datetime, str]] = {}
|
||||
for s in schedules:
|
||||
parsed = _parse_iso_datetime(s.next_run_time)
|
||||
if parsed is None:
|
||||
continue
|
||||
current = earliest.get(s.graph_id)
|
||||
if current is None or parsed < current[0]:
|
||||
earliest[s.graph_id] = (parsed, s.next_run_time)
|
||||
return {graph_id: iso for graph_id, (_, iso) in earliest.items()}
|
||||
except Exception:
|
||||
logger.warning("Failed to fetch schedules for library agents", exc_info=True)
|
||||
return {}
|
||||
|
||||
|
||||
def _parse_iso_datetime(value: str) -> Optional[datetime]:
|
||||
"""Parse an ISO 8601 datetime, tolerating `Z` and naive forms (assumed UTC)."""
|
||||
try:
|
||||
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
logger.warning("Failed to parse schedule next_run_time: %s", value)
|
||||
return None
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed
|
||||
|
||||
|
||||
async def list_library_agents(
|
||||
user_id: str,
|
||||
search_term: Optional[str] = None,
|
||||
@@ -197,22 +137,12 @@ async def list_library_agents(
|
||||
|
||||
logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")
|
||||
|
||||
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
|
||||
execution_counts, schedule_info = await asyncio.gather(
|
||||
_fetch_execution_counts(user_id, graph_ids),
|
||||
_fetch_schedule_info(user_id),
|
||||
)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent,
|
||||
execution_count_override=execution_counts.get(agent.agentGraphId),
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
@@ -284,22 +214,12 @@ async def list_favorite_library_agents(
|
||||
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
|
||||
)
|
||||
|
||||
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
|
||||
execution_counts, schedule_info = await asyncio.gather(
|
||||
_fetch_execution_counts(user_id, graph_ids),
|
||||
_fetch_schedule_info(user_id),
|
||||
)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent,
|
||||
execution_count_override=execution_counts.get(agent.agentGraphId),
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
@@ -365,12 +285,6 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
where={"userId": store_listing.owningUserId}
|
||||
)
|
||||
|
||||
schedule_info = (
|
||||
await _fetch_schedule_info(user_id, graph_id=library_agent.AgentGraph.id)
|
||||
if library_agent.AgentGraph
|
||||
else {}
|
||||
)
|
||||
|
||||
return library_model.LibraryAgent.from_db(
|
||||
library_agent,
|
||||
sub_graphs=(
|
||||
@@ -380,7 +294,6 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
),
|
||||
store_listing=store_listing,
|
||||
profile=profile,
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -416,25 +329,19 @@ async def get_library_agent_by_store_version_id(
|
||||
},
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
if not agent:
|
||||
return None
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
|
||||
return library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
return library_model.LibraryAgent.from_db(agent) if agent else None
|
||||
|
||||
|
||||
async def get_library_agent_by_graph_id(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_version: Optional[int] = None,
|
||||
include_archived: bool = False,
|
||||
) -> library_model.LibraryAgent | None:
|
||||
filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"agentGraphId": graph_id,
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
if not include_archived:
|
||||
filter["isArchived"] = False
|
||||
if graph_version is not None:
|
||||
filter["agentGraphVersion"] = graph_version
|
||||
|
||||
@@ -448,10 +355,7 @@ async def get_library_agent_by_graph_id(
|
||||
assert agent.AgentGraph # make type checker happy
|
||||
# Include sub-graphs so we can make a full credentials input schema
|
||||
sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph)
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
|
||||
return library_model.LibraryAgent.from_db(
|
||||
agent, sub_graphs=sub_graphs, schedule_info=schedule_info
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
@@ -529,58 +433,32 @@ async def create_library_agent(
|
||||
async with transaction() as tx:
|
||||
library_agents = await asyncio.gather(
|
||||
*(
|
||||
prisma.models.LibraryAgent.prisma(tx).upsert(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_entry.id,
|
||||
"agentGraphVersion": graph_entry.version,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"create": prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == graph.user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_entry.id,
|
||||
"version": graph_entry.version,
|
||||
}
|
||||
prisma.models.LibraryAgent.prisma(tx).create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_entry.id,
|
||||
"version": graph_entry.version,
|
||||
}
|
||||
},
|
||||
settings=SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
"update": {
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"useGraphIsActiveVersion": True,
|
||||
"settings": SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
}
|
||||
},
|
||||
},
|
||||
settings=SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
@@ -593,11 +471,7 @@ async def create_library_agent(
|
||||
for agent, graph in zip(library_agents, graph_entries):
|
||||
asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id))
|
||||
|
||||
schedule_info = await _fetch_schedule_info(user_id)
|
||||
return [
|
||||
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
for agent in library_agents
|
||||
]
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
|
||||
|
||||
|
||||
async def update_agent_version_in_library(
|
||||
@@ -659,8 +533,7 @@ async def update_agent_version_in_library(
|
||||
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
|
||||
)
|
||||
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent_graph_id)
|
||||
return library_model.LibraryAgent.from_db(lib, schedule_info=schedule_info)
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
|
||||
|
||||
async def create_graph_in_library(
|
||||
@@ -709,9 +582,7 @@ async def update_graph_in_library(
|
||||
|
||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||
|
||||
library_agent = await get_library_agent_by_graph_id(
|
||||
user_id, created_graph.id, include_archived=True
|
||||
)
|
||||
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
||||
if not library_agent:
|
||||
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
||||
|
||||
@@ -743,7 +614,6 @@ async def update_library_agent_version_and_settings(
|
||||
graph=agent_graph,
|
||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||
builder_chat_session_id=library.settings.builder_chat_session_id,
|
||||
)
|
||||
if updated_settings != library.settings:
|
||||
library = await update_library_agent(
|
||||
@@ -948,38 +818,92 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
|
||||
async def add_store_agent_to_library(
|
||||
store_listing_version_id: str, user_id: str
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Adds a marketplace agent to the user’s library.
|
||||
|
||||
See also: `add_store_agent_to_library_as_admin()` which uses
|
||||
`get_graph_as_admin` to bypass marketplace status checks for admin review.
|
||||
"""
|
||||
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
|
||||
Adds an agent from a store listing version to the user's library if they don't already have it.
|
||||
|
||||
Args:
|
||||
store_listing_version_id: The ID of the store listing version containing the agent.
|
||||
user_id: The user’s library to which the agent is being added.
|
||||
|
||||
Returns:
|
||||
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the store listing or associated agent is not found.
|
||||
DatabaseError: If there's an issue creating the LibraryAgent record.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Adding agent from store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
graph_model = await resolve_graph_for_library(
|
||||
store_listing_version_id, user_id, admin=False
|
||||
)
|
||||
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
|
||||
|
||||
|
||||
async def add_store_agent_to_library_as_admin(
|
||||
store_listing_version_id: str, user_id: str
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Admin variant that uses `get_graph_as_admin` to bypass marketplace
|
||||
APPROVED-only checks, allowing admins to add pending agents for review."""
|
||||
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
|
||||
|
||||
logger.warning(
|
||||
f"ADMIN adding agent from store listing version "
|
||||
f"#{store_listing_version_id} to library for user #{user_id}"
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
)
|
||||
)
|
||||
graph_model = await resolve_graph_for_library(
|
||||
store_listing_version_id, user_id, admin=True
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
graph = store_listing_version.AgentGraph
|
||||
|
||||
# Convert to GraphModel to check for HITL blocks
|
||||
graph_model = await graph_db.get_graph(
|
||||
graph_id=graph.id,
|
||||
version=graph.version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=False,
|
||||
)
|
||||
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
|
||||
if not graph_model:
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph.id} v{graph.version} not found or accessible"
|
||||
)
|
||||
|
||||
# Check if user already has this agent (non-deleted)
|
||||
if existing := await get_library_agent_by_graph_id(
|
||||
user_id, graph.id, graph.version
|
||||
):
|
||||
return existing
|
||||
|
||||
# Check for soft-deleted version and restore it
|
||||
deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
}
|
||||
},
|
||||
)
|
||||
if deleted_agent and deleted_agent.isDeleted:
|
||||
return await update_library_agent(deleted_agent.id, user_id, is_deleted=False)
|
||||
|
||||
# Create LibraryAgent entry
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": SafeJson(GraphSettings.from_graph(graph_model).model_dump()),
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
logger.debug(
|
||||
f"Added graph #{graph.id} v{graph.version}"
|
||||
f"for store listing version #{store_listing_version.id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
|
||||
|
||||
##############################################
|
||||
@@ -1566,11 +1490,7 @@ async def bulk_move_agents_to_folder(
|
||||
),
|
||||
)
|
||||
|
||||
schedule_info = await _fetch_schedule_info(user_id)
|
||||
return [
|
||||
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
for agent in agents
|
||||
]
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
|
||||
|
||||
|
||||
def collect_tree_ids(
|
||||
@@ -1804,7 +1724,7 @@ async def create_preset_from_graph_execution(
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph_execution.graph_id} not found or accessible"
|
||||
)
|
||||
elif len(graph.regular_credentials_inputs) > 0:
|
||||
elif len(graph.aggregate_credentials_inputs()) > 0:
|
||||
raise ValueError(
|
||||
f"Graph execution #{graph_exec_id} can't be turned into a preset "
|
||||
"because it was run before this feature existed "
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
@@ -65,11 +63,6 @@ async def test_get_library_agents(mocker):
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={}),
|
||||
)
|
||||
|
||||
# Call function
|
||||
result = await db.list_library_agents("test-user")
|
||||
|
||||
@@ -92,6 +85,10 @@ async def test_get_library_agents(mocker):
|
||||
async def test_add_agent_to_library(mocker):
|
||||
await connect()
|
||||
|
||||
# Mock the transaction context
|
||||
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
|
||||
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
|
||||
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
|
||||
# Mock data
|
||||
mock_store_listing_data = prisma.models.StoreListingVersion(
|
||||
id="version123",
|
||||
@@ -146,18 +143,15 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
)
|
||||
|
||||
# Mock graph_db.get_graph function that's called in resolve_graph_for_library
|
||||
# (lives in _add_to_library.py after refactor, not db.py)
|
||||
mock_graph_db = mocker.patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
)
|
||||
# Mock graph_db.get_graph function that's called to check for HITL blocks
|
||||
mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
|
||||
mock_graph_model = mocker.Mock()
|
||||
mock_graph_model.id = "agent1"
|
||||
mock_graph_model.version = 1
|
||||
mock_graph_model.nodes = (
|
||||
[]
|
||||
) # Empty list so _has_human_in_the_loop_blocks returns False
|
||||
@@ -176,27 +170,37 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
mock_library_agent.return_value.find_unique.assert_called_once_with(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Check that create was called with the expected data including settings
|
||||
create_call_args = mock_library_agent.return_value.create.call_args
|
||||
assert create_call_args is not None
|
||||
|
||||
# Verify the create data structure
|
||||
create_data = create_call_args.kwargs["data"]
|
||||
expected_create = {
|
||||
# Verify the main structure
|
||||
expected_data = {
|
||||
"User": {"connect": {"id": "test-user"}},
|
||||
"AgentGraph": {"connect": {"graphVersionId": {"id": "agent1", "version": 1}}},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
}
|
||||
for key, value in expected_create.items():
|
||||
assert create_data[key] == value
|
||||
|
||||
actual_data = create_call_args[1]["data"]
|
||||
# Check that all expected fields are present
|
||||
for key, value in expected_data.items():
|
||||
assert actual_data[key] == value
|
||||
|
||||
# Check that settings field is present and is a SafeJson object
|
||||
assert "settings" in create_data
|
||||
assert hasattr(create_data["settings"], "__class__") # Should be a SafeJson object
|
||||
assert "settings" in actual_data
|
||||
assert hasattr(actual_data["settings"], "__class__") # Should be a SafeJson object
|
||||
|
||||
# Check include parameter
|
||||
assert create_call_args.kwargs["include"] == library_agent_include(
|
||||
assert create_call_args[1]["include"] == library_agent_include(
|
||||
"test-user", include_nodes=False, include_executions=False
|
||||
)
|
||||
|
||||
@@ -220,274 +224,3 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agent_by_graph_id_excludes_archived(mocker):
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
|
||||
result = await db.get_library_agent_by_graph_id("test-user", "agent1", 7)
|
||||
|
||||
assert result is None
|
||||
mock_library_agent.return_value.find_first.assert_called_once()
|
||||
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
|
||||
assert where == {
|
||||
"agentGraphId": "agent1",
|
||||
"userId": "test-user",
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"agentGraphVersion": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agent_by_graph_id_can_include_archived(mocker):
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
|
||||
result = await db.get_library_agent_by_graph_id(
|
||||
"test-user",
|
||||
"agent1",
|
||||
7,
|
||||
include_archived=True,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
mock_library_agent.return_value.find_first.assert_called_once()
|
||||
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
|
||||
assert where == {
|
||||
"agentGraphId": "agent1",
|
||||
"userId": "test-user",
|
||||
"isDeleted": False,
|
||||
"agentGraphVersion": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_graph_in_library_allows_archived_library_agent(mocker):
|
||||
graph = mocker.Mock(id="graph-id")
|
||||
existing_version = mocker.Mock(version=1, is_active=True)
|
||||
graph_model = mocker.Mock()
|
||||
created_graph = mocker.Mock(id="graph-id", version=2, is_active=False)
|
||||
current_library_agent = mocker.Mock()
|
||||
updated_library_agent = mocker.Mock()
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.get_graph_all_versions",
|
||||
new=mocker.AsyncMock(return_value=[existing_version]),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.make_graph_model",
|
||||
return_value=graph_model,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.create_graph",
|
||||
new=mocker.AsyncMock(return_value=created_graph),
|
||||
)
|
||||
mock_get_library_agent = mocker.patch(
|
||||
"backend.api.features.library.db.get_library_agent_by_graph_id",
|
||||
new=mocker.AsyncMock(return_value=current_library_agent),
|
||||
)
|
||||
mock_update_library_agent = mocker.patch(
|
||||
"backend.api.features.library.db.update_library_agent_version_and_settings",
|
||||
new=mocker.AsyncMock(return_value=updated_library_agent),
|
||||
)
|
||||
|
||||
result_graph, result_library_agent = await db.update_graph_in_library(
|
||||
graph,
|
||||
"test-user",
|
||||
)
|
||||
|
||||
assert result_graph is created_graph
|
||||
assert result_library_agent is updated_library_agent
|
||||
assert graph.version == 2
|
||||
graph_model.reassign_ids.assert_called_once_with(
|
||||
user_id="test-user", reassign_graph_id=False
|
||||
)
|
||||
mock_get_library_agent.assert_awaited_once_with(
|
||||
"test-user",
|
||||
"graph-id",
|
||||
include_archived=True,
|
||||
)
|
||||
mock_update_library_agent.assert_awaited_once_with("test-user", created_graph)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_library_agent_uses_upsert():
|
||||
"""create_library_agent should use upsert (not create) to handle duplicates."""
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.id = "graph-1"
|
||||
mock_graph.version = 1
|
||||
mock_graph.user_id = "user-1"
|
||||
mock_graph.nodes = []
|
||||
mock_graph.sub_graphs = []
|
||||
|
||||
mock_upserted = MagicMock(name="UpsertedLibraryAgent")
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_tx():
|
||||
yield None
|
||||
|
||||
with (
|
||||
patch("backend.api.features.library.db.transaction", fake_tx),
|
||||
patch("prisma.models.LibraryAgent.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library.db.add_generated_agent_image",
|
||||
new=AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.upsert = AsyncMock(return_value=mock_upserted)
|
||||
|
||||
result = await db.create_library_agent(mock_graph, "user-1")
|
||||
|
||||
assert len(result) == 1
|
||||
upsert_call = mock_prisma.return_value.upsert.call_args
|
||||
assert upsert_call is not None
|
||||
# Verify the upsert where clause uses the composite unique key
|
||||
where = upsert_call.kwargs["where"]
|
||||
assert "userId_agentGraphId_agentGraphVersion" in where
|
||||
# Verify the upsert data has both create and update branches
|
||||
data = upsert_call.kwargs["data"]
|
||||
assert "create" in data
|
||||
assert "update" in data
|
||||
# Verify update branch restores soft-deleted/archived agents
|
||||
assert data["update"]["isDeleted"] is False
|
||||
assert data["update"]["isArchived"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_favorite_library_agents(mocker):
|
||||
mock_library_agents = [
|
||||
prisma.models.LibraryAgent(
|
||||
id="fav1",
|
||||
userId="test-user",
|
||||
agentGraphId="agent-fav",
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=True,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id="agent-fav",
|
||||
version=1,
|
||||
name="Favorite Agent",
|
||||
description="My Favorite",
|
||||
userId="other-user",
|
||||
isActive=True,
|
||||
createdAt=datetime.now(),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=mock_library_agents
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={"agent-fav": 7}),
|
||||
)
|
||||
|
||||
result = await db.list_favorite_library_agents("test-user")
|
||||
|
||||
assert len(result.agents) == 1
|
||||
assert result.agents[0].id == "fav1"
|
||||
assert result.agents[0].name == "Favorite Agent"
|
||||
assert result.agents[0].graph_id == "agent-fav"
|
||||
assert result.pagination.total_items == 1
|
||||
assert result.pagination.total_pages == 1
|
||||
assert result.pagination.current_page == 1
|
||||
assert result.pagination.page_size == 50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_library_agents_skips_failed_agent(mocker):
|
||||
"""Agents that fail parsing should be skipped — covers the except branch."""
|
||||
mock_library_agents = [
|
||||
prisma.models.LibraryAgent(
|
||||
id="ua-bad",
|
||||
userId="test-user",
|
||||
agentGraphId="agent-bad",
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id="agent-bad",
|
||||
version=1,
|
||||
name="Bad Agent",
|
||||
description="",
|
||||
userId="other-user",
|
||||
isActive=True,
|
||||
createdAt=datetime.now(),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=mock_library_agents
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={}),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db",
|
||||
side_effect=Exception("parse error"),
|
||||
)
|
||||
|
||||
result = await db.list_library_agents("test-user")
|
||||
|
||||
assert len(result.agents) == 0
|
||||
assert result.pagination.total_items == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_execution_counts_empty_graph_ids():
|
||||
result = await db._fetch_execution_counts("user-1", [])
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_execution_counts_uses_group_by(mocker):
|
||||
mock_prisma = mocker.patch("prisma.models.AgentGraphExecution.prisma")
|
||||
mock_prisma.return_value.group_by = mocker.AsyncMock(
|
||||
return_value=[
|
||||
{"agentGraphId": "graph-1", "_count": {"_all": 5}},
|
||||
{"agentGraphId": "graph-2", "_count": {"_all": 2}},
|
||||
]
|
||||
)
|
||||
|
||||
result = await db._fetch_execution_counts(
|
||||
"user-1", ["graph-1", "graph-2", "graph-3"]
|
||||
)
|
||||
|
||||
assert result == {"graph-1": 5, "graph-2": 2}
|
||||
mock_prisma.return_value.group_by.assert_called_once_with(
|
||||
by=["agentGraphId"],
|
||||
where={
|
||||
"userId": "user-1",
|
||||
"agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]},
|
||||
"isDeleted": False,
|
||||
},
|
||||
count=True,
|
||||
)
|
||||
|
||||
@@ -214,14 +214,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
folder_name: str | None = None # Denormalized for display
|
||||
|
||||
recommended_schedule_cron: str | None = None
|
||||
is_scheduled: bool = pydantic.Field(
|
||||
default=False,
|
||||
description="Whether this agent has active execution schedules",
|
||||
)
|
||||
next_scheduled_run: str | None = pydantic.Field(
|
||||
default=None,
|
||||
description="ISO 8601 timestamp of the next scheduled run, if any",
|
||||
)
|
||||
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
||||
marketplace_listing: Optional["MarketplaceListing"] = None
|
||||
|
||||
@@ -231,8 +223,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
|
||||
store_listing: Optional[prisma.models.StoreListing] = None,
|
||||
profile: Optional[prisma.models.Profile] = None,
|
||||
execution_count_override: Optional[int] = None,
|
||||
schedule_info: Optional[dict[str, str]] = None,
|
||||
) -> "LibraryAgent":
|
||||
"""
|
||||
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
|
||||
@@ -268,14 +258,10 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
status = status_result.status
|
||||
new_output = status_result.new_output
|
||||
|
||||
execution_count = (
|
||||
execution_count_override
|
||||
if execution_count_override is not None
|
||||
else len(executions)
|
||||
)
|
||||
execution_count = len(executions)
|
||||
success_rate: float | None = None
|
||||
avg_correctness_score: float | None = None
|
||||
if executions and execution_count > 0:
|
||||
if execution_count > 0:
|
||||
success_count = sum(
|
||||
1
|
||||
for e in executions
|
||||
@@ -368,10 +354,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
folder_id=agent.folderId,
|
||||
folder_name=agent.Folder.name if agent.Folder else None,
|
||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||
is_scheduled=bool(schedule_info and agent.agentGraphId in schedule_info),
|
||||
next_scheduled_run=(
|
||||
schedule_info.get(agent.agentGraphId) if schedule_info else None
|
||||
),
|
||||
settings=_parse_settings(agent.settings),
|
||||
marketplace_listing=marketplace_listing_data,
|
||||
)
|
||||
|
||||
@@ -1,66 +1,11 @@
|
||||
import datetime
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
from . import model as library_model
|
||||
|
||||
|
||||
def _make_library_agent(
|
||||
*,
|
||||
graph_id: str = "g1",
|
||||
executions: list | None = None,
|
||||
) -> prisma.models.LibraryAgent:
|
||||
return prisma.models.LibraryAgent(
|
||||
id="la1",
|
||||
userId="u1",
|
||||
agentGraphId=graph_id,
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.datetime.now(),
|
||||
updatedAt=datetime.datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id=graph_id,
|
||||
version=1,
|
||||
name="Agent",
|
||||
description="Desc",
|
||||
userId="u1",
|
||||
isActive=True,
|
||||
createdAt=datetime.datetime.now(),
|
||||
Executions=executions,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_from_db_execution_count_override_covers_success_rate():
|
||||
"""Covers execution_count_override is not None branch and executions/count > 0 block."""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
exec1 = prisma.models.AgentGraphExecution(
|
||||
id="exec-1",
|
||||
agentGraphId="g1",
|
||||
agentGraphVersion=1,
|
||||
userId="u1",
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
agent = _make_library_agent(executions=[exec1])
|
||||
|
||||
result = library_model.LibraryAgent.from_db(agent, execution_count_override=1)
|
||||
|
||||
assert result.execution_count == 1
|
||||
assert result.success_rate is not None
|
||||
assert result.success_rate == 100.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_preset_from_db(test_user_id: str):
|
||||
# Create mock DB agent
|
||||
|
||||
@@ -12,7 +12,6 @@ Tests cover:
|
||||
5. Complete OAuth flow end-to-end
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
@@ -59,27 +58,14 @@ async def test_user(server, test_user_id: str):
|
||||
|
||||
yield test_user_id
|
||||
|
||||
# Cleanup - delete in correct order due to foreign key constraints.
|
||||
# Wrap in try/except because the event loop or Prisma engine may already
|
||||
# be closed during session teardown on Python 3.12+.
|
||||
try:
|
||||
await asyncio.gather(
|
||||
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
|
||||
PrismaOAuthRefreshToken.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
)
|
||||
await asyncio.gather(
|
||||
PrismaOAuthApplication.prisma().delete_many(
|
||||
where={"ownerId": test_user_id}
|
||||
),
|
||||
PrismaUser.prisma().delete(where={"id": test_user_id}),
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
# Cleanup - delete in correct order due to foreign key constraints
|
||||
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
)
|
||||
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
|
||||
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
|
||||
from backend.api.features.v1 import v1_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(v1_router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_onboarding_profile_success(mocker):
|
||||
mock_extract = mocker.patch(
|
||||
"backend.api.features.v1.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
mock_upsert = mocker.patch(
|
||||
"backend.api.features.v1.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
from backend.data.understanding import BusinessUnderstandingInput
|
||||
|
||||
mock_extract.return_value = BusinessUnderstandingInput.model_construct(
|
||||
user_name="John",
|
||||
user_role="Founder/CEO",
|
||||
pain_points=["Finding leads"],
|
||||
suggested_prompts={"Learn": ["How do I automate lead gen?"]},
|
||||
)
|
||||
mock_upsert.return_value = AsyncMock()
|
||||
|
||||
response = client.post(
|
||||
"/onboarding/profile",
|
||||
json={
|
||||
"user_name": "John",
|
||||
"user_role": "Founder/CEO",
|
||||
"pain_points": ["Finding leads", "Email & outreach"],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_extract.assert_awaited_once()
|
||||
mock_upsert.assert_awaited_once()
|
||||
|
||||
|
||||
def test_onboarding_profile_missing_fields():
|
||||
response = client.post(
|
||||
"/onboarding/profile",
|
||||
json={"user_name": "John"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
@@ -1 +0,0 @@
|
||||
"""Platform bot linking — user-facing REST routes."""
|
||||
@@ -1,158 +0,0 @@
|
||||
"""User-facing platform_linking REST routes (JWT auth)."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, HTTPException, Path, Security
|
||||
|
||||
from backend.data.db_accessors import platform_linking_db
|
||||
from backend.platform_linking.models import (
|
||||
ConfirmLinkResponse,
|
||||
ConfirmUserLinkResponse,
|
||||
DeleteLinkResponse,
|
||||
LinkTokenInfoResponse,
|
||||
PlatformLinkInfo,
|
||||
PlatformUserLinkInfo,
|
||||
)
|
||||
from backend.util.exceptions import (
|
||||
LinkAlreadyExistsError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
TokenPath = Annotated[
|
||||
str,
|
||||
Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
|
||||
]
|
||||
|
||||
|
||||
def _translate(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, NotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(exc))
|
||||
if isinstance(exc, NotAuthorizedError):
|
||||
return HTTPException(status_code=403, detail=str(exc))
|
||||
if isinstance(exc, LinkAlreadyExistsError):
|
||||
return HTTPException(status_code=409, detail=str(exc))
|
||||
if isinstance(exc, LinkTokenExpiredError):
|
||||
return HTTPException(status_code=410, detail=str(exc))
|
||||
if isinstance(exc, LinkFlowMismatchError):
|
||||
return HTTPException(status_code=400, detail=str(exc))
|
||||
return HTTPException(status_code=500, detail="Internal error.")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tokens/{token}/info",
|
||||
response_model=LinkTokenInfoResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Get display info for a link token",
|
||||
)
|
||||
async def get_link_token_info_route(token: TokenPath) -> LinkTokenInfoResponse:
|
||||
try:
|
||||
return await platform_linking_db().get_link_token_info(token)
|
||||
except (NotFoundError, LinkTokenExpiredError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokens/{token}/confirm",
|
||||
response_model=ConfirmLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Confirm a SERVER link token (user must be authenticated)",
|
||||
)
|
||||
async def confirm_link_token(
|
||||
token: TokenPath,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConfirmLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().confirm_server_link(token, user_id)
|
||||
except (
|
||||
NotFoundError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
LinkAlreadyExistsError,
|
||||
) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/user-tokens/{token}/confirm",
|
||||
response_model=ConfirmUserLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Confirm a USER link token (user must be authenticated)",
|
||||
)
|
||||
async def confirm_user_link_token(
|
||||
token: TokenPath,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConfirmUserLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().confirm_user_link(token, user_id)
|
||||
except (
|
||||
NotFoundError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
LinkAlreadyExistsError,
|
||||
) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/links",
|
||||
response_model=list[PlatformLinkInfo],
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="List all platform servers linked to the authenticated user",
|
||||
)
|
||||
async def list_my_links(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> list[PlatformLinkInfo]:
|
||||
return await platform_linking_db().list_server_links(user_id)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/user-links",
|
||||
response_model=list[PlatformUserLinkInfo],
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="List all DM links for the authenticated user",
|
||||
)
|
||||
async def list_my_user_links(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> list[PlatformUserLinkInfo]:
|
||||
return await platform_linking_db().list_user_links(user_id)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/links/{link_id}",
|
||||
response_model=DeleteLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Unlink a platform server",
|
||||
)
|
||||
async def delete_link(
|
||||
link_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> DeleteLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().delete_server_link(link_id, user_id)
|
||||
except (NotFoundError, NotAuthorizedError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/user-links/{link_id}",
|
||||
response_model=DeleteLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Unlink a DM / user link",
|
||||
)
|
||||
async def delete_user_link_route(
|
||||
link_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> DeleteLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().delete_user_link(link_id, user_id)
|
||||
except (NotFoundError, NotAuthorizedError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
@@ -1,264 +0,0 @@
|
||||
"""Route tests: domain exceptions → HTTPException status codes."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from backend.util.exceptions import (
|
||||
LinkAlreadyExistsError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
)
|
||||
|
||||
|
||||
def _db_mock(**method_configs):
|
||||
"""Return a mock of the accessor's return value with the given AsyncMocks."""
|
||||
db = MagicMock()
|
||||
for name, mock in method_configs.items():
|
||||
setattr(db, name, mock)
|
||||
return db
|
||||
|
||||
|
||||
class TestTokenInfoRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_maps_to_404(self):
|
||||
from backend.api.features.platform_linking.routes import (
|
||||
get_link_token_info_route,
|
||||
)
|
||||
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await get_link_token_info_route(token="abc")
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_maps_to_410(self):
|
||||
from backend.api.features.platform_linking.routes import (
|
||||
get_link_token_info_route,
|
||||
)
|
||||
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=LinkTokenExpiredError("expired"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await get_link_token_info_route(token="abc")
|
||||
assert exc.value.status_code == 410
|
||||
|
||||
|
||||
class TestConfirmLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"exc,expected_status",
|
||||
[
|
||||
(NotFoundError("missing"), 404),
|
||||
(LinkFlowMismatchError("wrong flow"), 400),
|
||||
(LinkTokenExpiredError("expired"), 410),
|
||||
(LinkAlreadyExistsError("already"), 409),
|
||||
],
|
||||
)
|
||||
async def test_translation(self, exc: Exception, expected_status: int):
|
||||
from backend.api.features.platform_linking.routes import confirm_link_token
|
||||
|
||||
db = _db_mock(confirm_server_link=AsyncMock(side_effect=exc))
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as ctx:
|
||||
await confirm_link_token(token="abc", user_id="u1")
|
||||
assert ctx.value.status_code == expected_status
|
||||
|
||||
|
||||
class TestConfirmUserLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"exc,expected_status",
|
||||
[
|
||||
(NotFoundError("missing"), 404),
|
||||
(LinkFlowMismatchError("wrong flow"), 400),
|
||||
(LinkTokenExpiredError("expired"), 410),
|
||||
(LinkAlreadyExistsError("already"), 409),
|
||||
],
|
||||
)
|
||||
async def test_translation(self, exc: Exception, expected_status: int):
|
||||
from backend.api.features.platform_linking.routes import confirm_user_link_token
|
||||
|
||||
db = _db_mock(confirm_user_link=AsyncMock(side_effect=exc))
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as ctx:
|
||||
await confirm_user_link_token(token="abc", user_id="u1")
|
||||
assert ctx.value.status_code == expected_status
|
||||
|
||||
|
||||
class TestDeleteLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_maps_to_404(self):
|
||||
from backend.api.features.platform_linking.routes import delete_link
|
||||
|
||||
db = _db_mock(
|
||||
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_link(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_owned_maps_to_403(self):
|
||||
from backend.api.features.platform_linking.routes import delete_link
|
||||
|
||||
db = _db_mock(
|
||||
delete_server_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_link(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
class TestDeleteUserLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_maps_to_404(self):
|
||||
from backend.api.features.platform_linking.routes import delete_user_link_route
|
||||
|
||||
db = _db_mock(delete_user_link=AsyncMock(side_effect=NotFoundError("missing")))
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_user_link_route(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_owned_maps_to_403(self):
|
||||
from backend.api.features.platform_linking.routes import delete_user_link_route
|
||||
|
||||
db = _db_mock(
|
||||
delete_user_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_user_link_route(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
# ── Adversarial: malformed token path params ──────────────────────────
|
||||
|
||||
|
||||
class TestAdversarialTokenPath:
|
||||
# TokenPath enforces `^[A-Za-z0-9_-]+$` + max_length=64.
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
import fastapi
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import backend.api.features.platform_linking.routes as routes_mod
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.dependency_overrides[requires_user] = lambda: None
|
||||
app.dependency_overrides[get_user_id] = lambda: "caller-user"
|
||||
app.include_router(routes_mod.router, prefix="/api/platform-linking")
|
||||
return TestClient(app)
|
||||
|
||||
def test_rejects_token_with_special_chars(self, client):
|
||||
response = client.get("/api/platform-linking/tokens/bad%24token/info")
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_rejects_token_with_path_traversal(self, client):
|
||||
for probe in ("..%2F..", "foo..bar", "foo%2Fbar"):
|
||||
response = client.get(f"/api/platform-linking/tokens/{probe}/info")
|
||||
assert response.status_code in (
|
||||
404,
|
||||
422,
|
||||
), f"path-traversal probe {probe!r} returned {response.status_code}"
|
||||
|
||||
def test_rejects_token_too_long(self, client):
|
||||
long_token = "a" * 65
|
||||
response = client.get(f"/api/platform-linking/tokens/{long_token}/info")
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_accepts_token_at_max_length(self, client):
|
||||
token = "a" * 64
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
response = client.get(f"/api/platform-linking/tokens/{token}/info")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_accepts_urlsafe_b64_token_shape(self, client):
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
response = client.get("/api/platform-linking/tokens/abc-_XYZ123-_abc/info")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_confirm_rejects_malformed_token(self, client):
|
||||
response = client.post("/api/platform-linking/tokens/bad%24token/confirm")
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestAdversarialDeleteLinkId:
|
||||
"""DELETE link_id has no regex — ensure weird values are handled via
|
||||
NotFoundError (no crash, no cross-user leak)."""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
import fastapi
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import backend.api.features.platform_linking.routes as routes_mod
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.dependency_overrides[requires_user] = lambda: None
|
||||
app.dependency_overrides[get_user_id] = lambda: "caller-user"
|
||||
app.include_router(routes_mod.router, prefix="/api/platform-linking")
|
||||
return TestClient(app)
|
||||
|
||||
def test_weird_link_id_returns_404(self, client):
|
||||
db = _db_mock(
|
||||
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
for link_id in ("'; DROP TABLE links;--", "../../etc/passwd", ""):
|
||||
response = client.delete(f"/api/platform-linking/links/{link_id}")
|
||||
assert response.status_code in (404, 405)
|
||||
@@ -391,11 +391,6 @@ async def get_available_graph(
|
||||
async def get_store_agent_by_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get agent details from the StoreAgent view (APPROVED agents only).
|
||||
|
||||
See also: `get_store_agent_details_as_admin()` which bypasses the
|
||||
APPROVED-only StoreAgent view for admin preview of pending submissions.
|
||||
"""
|
||||
logger.debug(f"Getting store agent details for {store_listing_version_id}")
|
||||
|
||||
try:
|
||||
@@ -416,57 +411,6 @@ async def get_store_agent_by_version_id(
|
||||
raise DatabaseError("Failed to fetch agent details") from e
|
||||
|
||||
|
||||
async def get_store_agent_details_as_admin(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get agent details for admin preview, bypassing the APPROVED-only
|
||||
StoreAgent view. Queries StoreListingVersion directly so pending
|
||||
submissions are visible."""
|
||||
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id},
|
||||
include={
|
||||
"StoreListing": {"include": {"CreatorProfile": True}},
|
||||
},
|
||||
)
|
||||
if not slv or not slv.StoreListing:
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found"
|
||||
)
|
||||
|
||||
listing = slv.StoreListing
|
||||
# CreatorProfile is a required FK relation — should always exist.
|
||||
# If it's None, the DB is in a bad state.
|
||||
profile = listing.CreatorProfile
|
||||
if not profile:
|
||||
raise DatabaseError(
|
||||
f"StoreListing {listing.id} has no CreatorProfile — FK violated"
|
||||
)
|
||||
|
||||
return store_model.StoreAgentDetails(
|
||||
store_listing_version_id=slv.id,
|
||||
slug=listing.slug,
|
||||
agent_name=slv.name,
|
||||
agent_video=slv.videoUrl or "",
|
||||
agent_output_demo=slv.agentOutputDemoUrl or "",
|
||||
agent_image=slv.imageUrls,
|
||||
creator=profile.username,
|
||||
creator_avatar=profile.avatarUrl or "",
|
||||
sub_heading=slv.subHeading,
|
||||
description=slv.description,
|
||||
instructions=slv.instructions,
|
||||
categories=slv.categories,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
versions=[str(slv.version)],
|
||||
graph_id=slv.agentGraphId,
|
||||
graph_versions=[str(slv.agentGraphVersion)],
|
||||
last_updated=slv.updatedAt,
|
||||
recommended_schedule_cron=slv.recommendedScheduleCron,
|
||||
active_version_id=listing.activeVersionId or slv.id,
|
||||
has_approved_version=listing.hasApprovedVersion,
|
||||
)
|
||||
|
||||
|
||||
class StoreCreatorsSortOptions(Enum):
|
||||
# NOTE: values correspond 1:1 to columns of the Creator view
|
||||
AGENT_RATING = "agent_rating"
|
||||
|
||||
@@ -189,7 +189,6 @@ async def test_create_store_submission(mocker):
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="Europe/Delft",
|
||||
subscriptionTier=prisma.enums.SubscriptionTier.BASIC, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
|
||||
)
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
id="agent-id",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,8 +5,7 @@ import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Literal, Sequence, cast, get_args
|
||||
from urllib.parse import urlparse
|
||||
from typing import Annotated, Any, Sequence, get_args
|
||||
|
||||
import pydantic
|
||||
import stripe
|
||||
@@ -25,12 +24,10 @@ from fastapi import (
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from prisma.enums import SubscriptionTier
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
from backend.api.model import (
|
||||
CreateAPIKeyRequest,
|
||||
CreateAPIKeyResponse,
|
||||
@@ -50,24 +47,12 @@ from backend.data.auth import api_key as api_key_db
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
PendingChangeUnknown,
|
||||
RefundRequest,
|
||||
TransactionHistory,
|
||||
UserCredit,
|
||||
cancel_stripe_subscription,
|
||||
create_subscription_checkout,
|
||||
get_auto_top_up,
|
||||
get_pending_subscription_change,
|
||||
get_proration_credit_cents,
|
||||
get_subscription_price_id,
|
||||
get_user_credit_model,
|
||||
handle_subscription_payment_failure,
|
||||
modify_stripe_subscription_for_tier,
|
||||
release_pending_subscription_schedule,
|
||||
set_auto_top_up,
|
||||
set_subscription_tier,
|
||||
sync_subscription_from_stripe,
|
||||
sync_subscription_schedule_from_stripe,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
@@ -78,17 +63,12 @@ from backend.data.onboarding import (
|
||||
UserOnboardingUpdate,
|
||||
complete_onboarding_step,
|
||||
complete_re_run_agent,
|
||||
format_onboarding_for_extraction,
|
||||
get_recommended_agents,
|
||||
get_user_onboarding,
|
||||
onboarding_enabled,
|
||||
reset_user_onboarding,
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.tally import extract_business_understanding
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
@@ -97,7 +77,6 @@ from backend.data.user import (
|
||||
update_user_notification_preference,
|
||||
update_user_timezone,
|
||||
)
|
||||
from backend.data.workspace import get_workspace_file_by_id
|
||||
from backend.executor import scheduler
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
@@ -303,33 +282,35 @@ async def get_onboarding_agents(
|
||||
return await get_recommended_agents(user_id)
|
||||
|
||||
|
||||
class OnboardingProfileRequest(pydantic.BaseModel):
|
||||
"""Request body for onboarding profile submission."""
|
||||
|
||||
user_name: str = pydantic.Field(min_length=1, max_length=100)
|
||||
user_role: str = pydantic.Field(min_length=1, max_length=100)
|
||||
pain_points: list[str] = pydantic.Field(default_factory=list, max_length=20)
|
||||
|
||||
|
||||
class OnboardingStatusResponse(pydantic.BaseModel):
|
||||
"""Response for onboarding completion check."""
|
||||
"""Response for onboarding status check."""
|
||||
|
||||
is_completed: bool
|
||||
is_onboarding_enabled: bool
|
||||
is_chat_enabled: bool
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/onboarding/completed",
|
||||
summary="Check if onboarding is completed",
|
||||
"/onboarding/enabled",
|
||||
summary="Is onboarding enabled",
|
||||
tags=["onboarding", "public"],
|
||||
response_model=OnboardingStatusResponse,
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def is_onboarding_completed(
|
||||
async def is_onboarding_enabled(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> OnboardingStatusResponse:
|
||||
user_onboarding = await get_user_onboarding(user_id)
|
||||
# Check if chat is enabled for user
|
||||
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
|
||||
|
||||
# If chat is enabled, skip legacy onboarding
|
||||
if is_chat_enabled:
|
||||
return OnboardingStatusResponse(
|
||||
is_onboarding_enabled=False,
|
||||
is_chat_enabled=True,
|
||||
)
|
||||
|
||||
return OnboardingStatusResponse(
|
||||
is_completed=OnboardingStep.VISIT_COPILOT in user_onboarding.completedSteps,
|
||||
is_onboarding_enabled=await onboarding_enabled(),
|
||||
is_chat_enabled=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -344,38 +325,6 @@ async def reset_onboarding(user_id: Annotated[str, Security(get_user_id)]):
|
||||
return await reset_user_onboarding(user_id)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/onboarding/profile",
|
||||
summary="Submit onboarding profile",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def submit_onboarding_profile(
|
||||
data: OnboardingProfileRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
):
|
||||
formatted = format_onboarding_for_extraction(
|
||||
user_name=data.user_name,
|
||||
user_role=data.user_role,
|
||||
pain_points=data.pain_points,
|
||||
)
|
||||
|
||||
try:
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
except Exception:
|
||||
understanding_input = BusinessUnderstandingInput.model_construct()
|
||||
|
||||
# Ensure the direct fields are set even if LLM missed them
|
||||
understanding_input.user_name = data.user_name
|
||||
understanding_input.user_role = data.user_role
|
||||
if not understanding_input.pain_points:
|
||||
understanding_input.pain_points = data.pain_points
|
||||
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Blocks ###########################
|
||||
########################################################
|
||||
@@ -643,11 +592,6 @@ async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
|
||||
async def configure_user_auto_top_up(
|
||||
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> str:
|
||||
"""Configure auto top-up settings and perform an immediate top-up if needed.
|
||||
|
||||
Raises HTTPException(422) if the request parameters are invalid or if
|
||||
the credit top-up fails.
|
||||
"""
|
||||
if request.threshold < 0:
|
||||
raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
|
||||
if request.amount < 500 and request.amount != 0:
|
||||
@@ -662,27 +606,14 @@ async def configure_user_auto_top_up(
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
|
||||
try:
|
||||
if current_balance < request.threshold:
|
||||
await user_credit_model.top_up_credits(user_id, request.amount)
|
||||
else:
|
||||
await user_credit_model.top_up_credits(user_id, 0)
|
||||
except ValueError as e:
|
||||
known_messages = (
|
||||
"must not be negative",
|
||||
"already exists for user",
|
||||
"No payment method found",
|
||||
)
|
||||
if any(msg in str(e) for msg in known_messages):
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
raise
|
||||
if current_balance < request.threshold:
|
||||
await user_credit_model.top_up_credits(user_id, request.amount)
|
||||
else:
|
||||
await user_credit_model.top_up_credits(user_id, 0)
|
||||
|
||||
try:
|
||||
await set_auto_top_up(
|
||||
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
await set_auto_top_up(
|
||||
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
|
||||
)
|
||||
return "Auto top-up settings updated"
|
||||
|
||||
|
||||
@@ -698,433 +629,41 @@ async def get_user_auto_top_up(
|
||||
return await get_auto_top_up(user_id)
|
||||
|
||||
|
||||
class SubscriptionTierRequest(BaseModel):
|
||||
tier: Literal["BASIC", "PRO", "MAX", "BUSINESS"]
|
||||
success_url: str = ""
|
||||
cancel_url: str = ""
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
tier: Literal["BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
|
||||
monthly_cost: int # amount in cents (Stripe convention)
|
||||
tier_costs: dict[str, int] # tier name -> amount in cents
|
||||
proration_credit_cents: int # unused portion of current sub to convert on upgrade
|
||||
pending_tier: Optional[Literal["BASIC", "PRO", "MAX", "BUSINESS"]] = None
|
||||
pending_tier_effective_at: Optional[datetime] = None
|
||||
url: str = Field(
|
||||
default="",
|
||||
description=(
|
||||
"Populated only when POST /credits/subscription starts a Stripe Checkout"
|
||||
" Session (BASIC → paid upgrade). Empty string in all other branches —"
|
||||
" the client redirects to this URL when non-empty."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _validate_checkout_redirect_url(url: str) -> bool:
|
||||
"""Return True if `url` matches the configured frontend origin.
|
||||
|
||||
Prevents open-redirect: attackers must not be able to supply arbitrary
|
||||
success_url/cancel_url that Stripe will redirect users to after checkout.
|
||||
|
||||
Pre-parse rejection rules (applied before urlparse):
|
||||
- Backslashes (``\\``) are normalised differently across parsers/browsers.
|
||||
- Control characters (U+0000–U+001F) are not valid in URLs and may confuse
|
||||
some URL-parsing implementations.
|
||||
"""
|
||||
# Reject characters that can confuse URL parsers before any parsing.
|
||||
if "\\" in url:
|
||||
return False
|
||||
if any(ord(c) < 0x20 for c in url):
|
||||
return False
|
||||
|
||||
allowed = settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
if not allowed:
|
||||
# No configured origin — refuse to validate rather than allow arbitrary URLs.
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
allowed_parsed = urlparse(allowed)
|
||||
except ValueError:
|
||||
return False
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
# Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
|
||||
# can trick browsers into connecting to a different host than displayed.
|
||||
# ``@`` in query/fragment is harmless and must be allowed.
|
||||
if "@" in parsed.netloc:
|
||||
return False
|
||||
return (
|
||||
parsed.scheme == allowed_parsed.scheme
|
||||
and parsed.netloc == allowed_parsed.netloc
|
||||
)
|
||||
|
||||
|
||||
@cached(ttl_seconds=300, maxsize=32, cache_none=False)
|
||||
async def _get_stripe_price_amount(price_id: str) -> int | None:
|
||||
"""Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
|
||||
|
||||
Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out
|
||||
of caching the ``None`` sentinel so the next request retries Stripe instead
|
||||
of being served a stale "no price" for the rest of the TTL window. Callers
|
||||
should treat ``None`` as an unknown price and fall back to 0.
|
||||
|
||||
Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
|
||||
every GET /credits/subscription page load and reduces quota consumption.
|
||||
"""
|
||||
try:
|
||||
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
|
||||
return price.unit_amount or 0
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"Failed to retrieve Stripe price %s — returning None (not cached)",
|
||||
price_id,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/credits/subscription",
|
||||
summary="Get subscription tier, current cost, and all tier costs",
|
||||
operation_id="getSubscriptionStatus",
|
||||
tags=["credits"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_subscription_status(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> SubscriptionStatusResponse:
|
||||
user = await get_user_by_id(user_id)
|
||||
tier = user.subscription_tier or SubscriptionTier.BASIC
|
||||
|
||||
priceable_tiers = [
|
||||
SubscriptionTier.BASIC,
|
||||
SubscriptionTier.PRO,
|
||||
SubscriptionTier.MAX,
|
||||
SubscriptionTier.BUSINESS,
|
||||
]
|
||||
price_ids = await asyncio.gather(
|
||||
*[get_subscription_price_id(t) for t in priceable_tiers]
|
||||
)
|
||||
|
||||
async def _cost(pid: str | None) -> int:
|
||||
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
|
||||
|
||||
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
|
||||
|
||||
tier_costs: dict[str, int] = {}
|
||||
for t, pid, cost in zip(priceable_tiers, price_ids, costs):
|
||||
if pid:
|
||||
tier_costs[t.value] = cost
|
||||
|
||||
current_monthly_cost = tier_costs.get(tier.value, 0)
|
||||
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
|
||||
|
||||
try:
|
||||
pending = await get_pending_subscription_change(user_id)
|
||||
except (stripe.StripeError, PendingChangeUnknown):
|
||||
# Swallow Stripe-side failures (rate limits, transient network) AND
|
||||
# PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both
|
||||
# propagate past the cache so the next request retries fresh instead
|
||||
# of serving a stale None for the TTL window. Let real bugs (KeyError,
|
||||
# AttributeError, etc.) propagate so they surface in Sentry.
|
||||
logger.exception(
|
||||
"get_subscription_status: failed to resolve pending change for user %s",
|
||||
user_id,
|
||||
)
|
||||
pending = None
|
||||
|
||||
response = SubscriptionStatusResponse(
|
||||
tier=tier.value,
|
||||
monthly_cost=current_monthly_cost,
|
||||
tier_costs=tier_costs,
|
||||
proration_credit_cents=proration_credit,
|
||||
)
|
||||
if pending is not None:
|
||||
pending_tier_enum, pending_effective_at = pending
|
||||
if pending_tier_enum in (
|
||||
SubscriptionTier.BASIC,
|
||||
SubscriptionTier.PRO,
|
||||
SubscriptionTier.MAX,
|
||||
SubscriptionTier.BUSINESS,
|
||||
):
|
||||
response.pending_tier = pending_tier_enum.value
|
||||
response.pending_tier_effective_at = pending_effective_at
|
||||
return response
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/subscription",
|
||||
summary="Update subscription tier or start a Stripe Checkout session",
|
||||
operation_id="updateSubscriptionTier",
|
||||
tags=["credits"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_subscription_tier(
|
||||
request: SubscriptionTierRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> SubscriptionStatusResponse:
|
||||
# Pydantic validates tier is one of BASIC/PRO/MAX/BUSINESS via Literal type.
|
||||
tier = SubscriptionTier(request.tier)
|
||||
|
||||
# ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
|
||||
user = await get_user_by_id(user_id)
|
||||
if (
|
||||
user.subscription_tier or SubscriptionTier.BASIC
|
||||
) == SubscriptionTier.ENTERPRISE:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="ENTERPRISE subscription changes must be managed by an administrator",
|
||||
)
|
||||
|
||||
# Same-tier request = "stay on my current tier" = cancel any pending
|
||||
# scheduled change (paid→paid downgrade or paid→BASIC cancel). This is the
|
||||
# collapsed behaviour that replaces the old /credits/subscription/cancel-pending
|
||||
# route. Safe when no pending change exists: release_pending_subscription_schedule
|
||||
# returns False and we simply return the current status.
|
||||
if (user.subscription_tier or SubscriptionTier.BASIC) == tier:
|
||||
try:
|
||||
await release_pending_subscription_schedule(user_id)
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error releasing pending subscription change for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to cancel the pending subscription change right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
return await get_subscription_status(user_id)
|
||||
|
||||
payment_enabled = await is_feature_enabled(
|
||||
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
|
||||
)
|
||||
|
||||
current_tier = user.subscription_tier or SubscriptionTier.BASIC
|
||||
target_price_id, current_tier_price_id = await asyncio.gather(
|
||||
get_subscription_price_id(tier),
|
||||
get_subscription_price_id(current_tier),
|
||||
)
|
||||
|
||||
# Legacy cancel: target BASIC + stripe-price-id-basic unset. Schedule Stripe
|
||||
# cancellation at period end; cancel_at_period_end=True lets the webhook flip
|
||||
# the DB tier. No active sub (admin-granted) or payment disabled → DB flip.
|
||||
# Once stripe-price-id-basic is configured, BASIC becomes a real sub and falls
|
||||
# through to the modify/checkout flow below.
|
||||
if tier == SubscriptionTier.BASIC and target_price_id is None:
|
||||
if payment_enabled:
|
||||
try:
|
||||
had_subscription = await cancel_stripe_subscription(user_id)
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error cancelling subscription for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to cancel your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
if not had_subscription:
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return await get_subscription_status(user_id)
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return await get_subscription_status(user_id)
|
||||
|
||||
if not payment_enabled:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Subscription not available for tier {tier.value}",
|
||||
)
|
||||
|
||||
# Target has no LD price — not provisionable (matches the GET hiding).
|
||||
if target_price_id is None:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Subscription not available for tier {tier.value}",
|
||||
)
|
||||
|
||||
# User has an active Stripe subscription (current tier has an LD price):
|
||||
# modify it in-place. modify_stripe_subscription_for_tier returns False when no
|
||||
# active sub exists — that's only a "DB-only flip is OK" signal for admin-granted
|
||||
# paid tiers (PRO/BUSINESS with no Stripe record). Priced-BASIC users without a
|
||||
# sub must still go through Checkout so they set up payment.
|
||||
if current_tier_price_id is not None:
|
||||
try:
|
||||
modified = await modify_stripe_subscription_for_tier(user_id, tier)
|
||||
if modified:
|
||||
return await get_subscription_status(user_id)
|
||||
if current_tier != SubscriptionTier.BASIC:
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return await get_subscription_status(user_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error modifying subscription for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to update your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
# No active Stripe subscription → create Stripe Checkout Session.
|
||||
if not request.success_url or not request.cancel_url:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url are required for paid tier upgrades",
|
||||
)
|
||||
# Open-redirect protection: both URLs must point to the configured frontend
|
||||
# origin, otherwise an attacker could use our Stripe integration as a
|
||||
# redirector to arbitrary phishing sites.
|
||||
#
|
||||
# Fail early with a clear 503 if the server is misconfigured (neither
|
||||
# frontend_base_url nor platform_base_url set), so operators get an
|
||||
# actionable error instead of the misleading "must match the platform
|
||||
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
|
||||
# produce when `allowed` is empty.
|
||||
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
|
||||
logger.error(
|
||||
"update_subscription_tier: neither frontend_base_url nor "
|
||||
"platform_base_url is configured; cannot validate checkout redirect URLs"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Payment redirect URLs cannot be validated: "
|
||||
"frontend_base_url or platform_base_url must be set on the server."
|
||||
),
|
||||
)
|
||||
if not _validate_checkout_redirect_url(
|
||||
request.success_url
|
||||
) or not _validate_checkout_redirect_url(request.cancel_url):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url must match the platform frontend origin",
|
||||
)
|
||||
try:
|
||||
url = await create_subscription_checkout(
|
||||
user_id=user_id,
|
||||
tier=tier,
|
||||
success_url=request.success_url,
|
||||
cancel_url=request.cancel_url,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error creating checkout session for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to start checkout right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
status = await get_subscription_status(user_id)
|
||||
status.url = url
|
||||
return status
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
|
||||
)
|
||||
async def stripe_webhook(request: Request):
|
||||
webhook_secret = settings.secrets.stripe_webhook_secret
|
||||
if not webhook_secret:
|
||||
# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
|
||||
# signature over the same empty key). Reject all webhook calls when unconfigured.
|
||||
logger.error(
|
||||
"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
|
||||
"rejecting request to prevent signature bypass"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Webhook not configured")
|
||||
|
||||
# Get the raw request body
|
||||
payload = await request.body()
|
||||
# Get the signature header
|
||||
sig_header = request.headers.get("stripe-signature")
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
|
||||
except ValueError:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.secrets.stripe_webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise HTTPException(status_code=400, detail="Invalid payload")
|
||||
except stripe.SignatureVerificationError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
|
||||
)
|
||||
except stripe.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
# Defensive payload extraction. A malformed payload (missing/non-dict
|
||||
# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
|
||||
# AFTER signature verification — which Stripe interprets as a delivery
|
||||
# failure and retries forever, while spamming Sentry with no useful info.
|
||||
# Acknowledge with 200 and a warning so Stripe stops retrying.
|
||||
event_type = event.get("type", "")
|
||||
event_data = event.get("data") or {}
|
||||
data_object = event_data.get("object") if isinstance(event_data, dict) else None
|
||||
if not isinstance(data_object, dict):
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing or non-dict data.object; ignoring",
|
||||
event_type,
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
|
||||
)
|
||||
return Response(status_code=200)
|
||||
|
||||
if event_type in (
|
||||
"checkout.session.completed",
|
||||
"checkout.session.async_payment_succeeded",
|
||||
if (
|
||||
event["type"] == "checkout.session.completed"
|
||||
or event["type"] == "checkout.session.async_payment_succeeded"
|
||||
):
|
||||
session_id = data_object.get("id")
|
||||
if not session_id:
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing data.object.id; ignoring", event_type
|
||||
)
|
||||
return Response(status_code=200)
|
||||
await UserCredit().fulfill_checkout(session_id=session_id)
|
||||
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
|
||||
|
||||
if event_type in (
|
||||
"customer.subscription.created",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
):
|
||||
await sync_subscription_from_stripe(data_object)
|
||||
if event["type"] == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(event["data"]["object"])
|
||||
|
||||
# `subscription_schedule.updated` is deliberately omitted: our own
|
||||
# `SubscriptionSchedule.create` + `.modify` calls in
|
||||
# `_schedule_downgrade_at_period_end` would fire that event right back at us
|
||||
# and loop redundant traffic through this handler. We only care about state
|
||||
# transitions (released / completed); phase advance to the new price is
|
||||
# already covered by `customer.subscription.updated`.
|
||||
if event_type in (
|
||||
"subscription_schedule.released",
|
||||
"subscription_schedule.completed",
|
||||
):
|
||||
await sync_subscription_schedule_from_stripe(data_object)
|
||||
|
||||
if event_type == "invoice.payment_failed":
|
||||
await handle_subscription_payment_failure(data_object)
|
||||
|
||||
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
|
||||
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
|
||||
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
|
||||
# to satisfy the type checker without changing runtime behaviour.
|
||||
if event_type == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
|
||||
|
||||
if event_type == "refund.created" or event_type == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(
|
||||
cast("stripe.Refund | stripe.Dispute", data_object)
|
||||
)
|
||||
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(event["data"]["object"])
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
@@ -1426,16 +965,14 @@ async def execute_graph(
|
||||
source: Annotated[GraphExecutionSource | None, Body(embed=True)] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
dry_run: Annotated[bool, Body(embed=True)] = False,
|
||||
) -> execution_db.GraphExecutionMeta:
|
||||
if not dry_run:
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await execution_utils.add_graph_execution(
|
||||
@@ -1445,7 +982,6 @@ async def execute_graph(
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
# Record successful graph execution
|
||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||
@@ -1708,10 +1244,6 @@ async def enable_execution_sharing(
|
||||
# Generate a unique share token
|
||||
share_token = str(uuid.uuid4())
|
||||
|
||||
# Remove stale allowlist records before updating the token — prevents a
|
||||
# window where old records + new token could coexist.
|
||||
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
|
||||
|
||||
# Update the execution with share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
@@ -1721,14 +1253,6 @@ async def enable_execution_sharing(
|
||||
shared_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Create allowlist of workspace files referenced in outputs
|
||||
await execution_db.create_shared_execution_files(
|
||||
execution_id=graph_exec_id,
|
||||
share_token=share_token,
|
||||
user_id=user_id,
|
||||
outputs=execution.outputs,
|
||||
)
|
||||
|
||||
# Return the share URL
|
||||
frontend_url = settings.config.frontend_base_url or "http://localhost:3000"
|
||||
share_url = f"{frontend_url}/share/{share_token}"
|
||||
@@ -1754,9 +1278,6 @@ async def disable_execution_sharing(
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
# Remove shared file allowlist records
|
||||
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
|
||||
|
||||
# Remove share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
@@ -1782,43 +1303,6 @@ async def get_shared_execution(
|
||||
return execution
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/public/shared/{share_token}/files/{file_id}/download",
|
||||
summary="Download a file from a shared execution",
|
||||
operation_id="download_shared_file",
|
||||
tags=["graphs"],
|
||||
)
|
||||
async def download_shared_file(
|
||||
share_token: Annotated[
|
||||
str,
|
||||
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
|
||||
],
|
||||
file_id: Annotated[
|
||||
str,
|
||||
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
|
||||
],
|
||||
) -> Response:
|
||||
"""Download a workspace file from a shared execution (no auth required).
|
||||
|
||||
Validates that the file was explicitly exposed when sharing was enabled.
|
||||
Returns a uniform 404 for all failure modes to prevent enumeration attacks.
|
||||
"""
|
||||
# Single-query validation against the allowlist
|
||||
execution_id = await execution_db.get_shared_execution_file(
|
||||
share_token=share_token, file_id=file_id
|
||||
)
|
||||
if not execution_id:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
|
||||
# Look up the actual file (no workspace scoping needed — the allowlist
|
||||
# already validated that this file belongs to the shared execution)
|
||||
file = await get_workspace_file_by_id(file_id)
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
|
||||
return await create_file_download_response(file, inline=True)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Schedules ########################
|
||||
########################################################
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
"""Tests for the public shared file download endpoint."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.responses import Response
|
||||
|
||||
from backend.api.features.v1 import v1_router
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(v1_router, prefix="/api")
|
||||
|
||||
VALID_TOKEN = "550e8400-e29b-41d4-a716-446655440000"
|
||||
VALID_FILE_ID = "6ba7b810-9dad-11d1-80b4-00c04fd430c8"
|
||||
|
||||
|
||||
def _make_workspace_file(**overrides) -> WorkspaceFile:
|
||||
defaults = {
|
||||
"id": VALID_FILE_ID,
|
||||
"workspace_id": "ws-001",
|
||||
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"name": "image.png",
|
||||
"path": "/image.png",
|
||||
"storage_path": "local://uploads/image.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 4,
|
||||
"checksum": None,
|
||||
"is_deleted": False,
|
||||
"deleted_at": None,
|
||||
"metadata": {},
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return WorkspaceFile(**defaults)
|
||||
|
||||
|
||||
def _mock_download_response(**kwargs):
|
||||
"""Return an AsyncMock that resolves to a Response with inline disposition."""
|
||||
|
||||
async def _handler(file, *, inline=False):
|
||||
return Response(
|
||||
content=b"\x89PNG",
|
||||
media_type="image/png",
|
||||
headers={
|
||||
"Content-Disposition": (
|
||||
'inline; filename="image.png"'
|
||||
if inline
|
||||
else 'attachment; filename="image.png"'
|
||||
),
|
||||
"Content-Length": "4",
|
||||
},
|
||||
)
|
||||
|
||||
return _handler
|
||||
|
||||
|
||||
class TestDownloadSharedFile:
|
||||
"""Tests for GET /api/public/shared/{token}/files/{id}/download."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _client(self):
|
||||
self.client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
def test_valid_token_and_file_returns_inline_content(self):
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value="exec-123",
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.get_workspace_file_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_make_workspace_file(),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.create_file_download_response",
|
||||
side_effect=_mock_download_response(),
|
||||
),
|
||||
):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"\x89PNG"
|
||||
assert "inline" in response.headers["Content-Disposition"]
|
||||
|
||||
def test_invalid_token_format_returns_422(self):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/not-a-uuid/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_token_not_in_allowlist_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_file_missing_from_workspace_returns_404(self):
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value="exec-123",
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.get_workspace_file_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_uniform_404_prevents_enumeration(self):
|
||||
"""Both failure modes produce identical 404 — no information leak."""
|
||||
with patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
resp_no_allow = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value="exec-123",
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.get_workspace_file_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
resp_no_file = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
|
||||
assert resp_no_allow.status_code == 404
|
||||
assert resp_no_file.status_code == 404
|
||||
assert resp_no_allow.json() == resp_no_file.json()
|
||||
@@ -12,7 +12,7 @@ import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
@@ -29,9 +29,7 @@ from backend.util.workspace import WorkspaceManager
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
|
||||
def _sanitize_filename_for_header(
|
||||
filename: str, disposition: str = "attachment"
|
||||
) -> str:
|
||||
def _sanitize_filename_for_header(filename: str) -> str:
|
||||
"""
|
||||
Sanitize filename for Content-Disposition header to prevent header injection.
|
||||
|
||||
@@ -46,11 +44,11 @@ def _sanitize_filename_for_header(
|
||||
# Check if filename has non-ASCII characters
|
||||
try:
|
||||
sanitized.encode("ascii")
|
||||
return f'{disposition}; filename="{sanitized}"'
|
||||
return f'attachment; filename="{sanitized}"'
|
||||
except UnicodeEncodeError:
|
||||
# Use RFC5987 encoding for UTF-8 filenames
|
||||
encoded = quote(sanitized, safe="")
|
||||
return f"{disposition}; filename*=UTF-8''{encoded}"
|
||||
return f"attachment; filename*=UTF-8''{encoded}"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -60,26 +58,19 @@ router = fastapi.APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _create_streaming_response(
|
||||
content: bytes, file: WorkspaceFile, *, inline: bool = False
|
||||
) -> Response:
|
||||
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
|
||||
"""Create a streaming response for file content."""
|
||||
disposition = _sanitize_filename_for_header(
|
||||
file.name, disposition="inline" if inline else "attachment"
|
||||
)
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=file.mime_type,
|
||||
headers={
|
||||
"Content-Disposition": disposition,
|
||||
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
||||
"Content-Length": str(len(content)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def create_file_download_response(
|
||||
file: WorkspaceFile, *, inline: bool = False
|
||||
) -> Response:
|
||||
async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
"""
|
||||
Create a download response for a workspace file.
|
||||
|
||||
@@ -91,7 +82,7 @@ async def create_file_download_response(
|
||||
# For local storage, stream the file directly
|
||||
if file.storage_path.startswith("local://"):
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file, inline=inline)
|
||||
return _create_streaming_response(content, file)
|
||||
|
||||
# For GCS, try to redirect to signed URL, fall back to streaming
|
||||
try:
|
||||
@@ -99,7 +90,7 @@ async def create_file_download_response(
|
||||
# If we got back an API path (fallback), stream directly instead
|
||||
if url.startswith("/api/"):
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file, inline=inline)
|
||||
return _create_streaming_response(content, file)
|
||||
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log the signed URL failure with context
|
||||
@@ -111,7 +102,7 @@ async def create_file_download_response(
|
||||
# Fall back to streaming directly from GCS
|
||||
try:
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file, inline=inline)
|
||||
return _create_streaming_response(content, file)
|
||||
except Exception as fallback_error:
|
||||
logger.error(
|
||||
f"Fallback streaming also failed for file {file.id} "
|
||||
@@ -140,26 +131,9 @@ class StorageUsageResponse(BaseModel):
|
||||
file_count: int
|
||||
|
||||
|
||||
class WorkspaceFileItem(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
created_at: str
|
||||
|
||||
|
||||
class ListFilesResponse(BaseModel):
|
||||
files: list[WorkspaceFileItem]
|
||||
offset: int = 0
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files/{file_id}/download",
|
||||
summary="Download file by ID",
|
||||
operation_id="getWorkspaceDownloadFileById",
|
||||
)
|
||||
async def download_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -178,13 +152,12 @@ async def download_file(
|
||||
if file is None:
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return await create_file_download_response(file)
|
||||
return await _create_file_download_response(file)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/files/{file_id}",
|
||||
summary="Delete a workspace file",
|
||||
operation_id="deleteWorkspaceFile",
|
||||
)
|
||||
async def delete_workspace_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -210,13 +183,11 @@ async def delete_workspace_file(
|
||||
@router.post(
|
||||
"/files/upload",
|
||||
summary="Upload file to workspace",
|
||||
operation_id="uploadWorkspaceFile",
|
||||
)
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file: UploadFile,
|
||||
session_id: str | None = Query(default=None),
|
||||
overwrite: bool = Query(default=False),
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
Upload a file to the user's workspace.
|
||||
@@ -224,9 +195,6 @@ async def upload_file(
|
||||
Files are stored in session-scoped paths when session_id is provided,
|
||||
so the agent's session-scoped tools can discover them automatically.
|
||||
"""
|
||||
# Empty-string session_id drops session scoping; normalize to None.
|
||||
session_id = session_id or None
|
||||
|
||||
config = Config()
|
||||
|
||||
# Sanitize filename — strip any directory components
|
||||
@@ -280,28 +248,15 @@ async def upload_file(
|
||||
# Write file via WorkspaceManager
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
try:
|
||||
workspace_file = await manager.write_file(
|
||||
content, filename, overwrite=overwrite, metadata={"origin": "user-upload"}
|
||||
)
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
except ValueError as e:
|
||||
# write_file raises ValueError for both path-conflict and size-limit
|
||||
# cases; map each to its correct HTTP status.
|
||||
message = str(e)
|
||||
if message.startswith("File too large"):
|
||||
raise fastapi.HTTPException(status_code=413, detail=message) from e
|
||||
raise fastapi.HTTPException(status_code=409, detail=message) from e
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
new_total = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and new_total > storage_limit_bytes:
|
||||
try:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to soft-delete over-quota file {workspace_file.id} "
|
||||
f"in workspace {workspace.id}: {e}"
|
||||
)
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
@@ -323,7 +278,6 @@ async def upload_file(
|
||||
@router.get(
|
||||
"/storage/usage",
|
||||
summary="Get workspace storage usage",
|
||||
operation_id="getWorkspaceStorageUsage",
|
||||
)
|
||||
async def get_storage_usage(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -344,57 +298,3 @@ async def get_storage_usage(
|
||||
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files",
|
||||
summary="List workspace files",
|
||||
operation_id="listWorkspaceFiles",
|
||||
)
|
||||
async def list_workspace_files(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
session_id: str | None = Query(default=None),
|
||||
limit: int = Query(default=200, ge=1, le=1000),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> ListFilesResponse:
|
||||
"""
|
||||
List files in the user's workspace.
|
||||
|
||||
When session_id is provided, only files for that session are returned.
|
||||
Otherwise, all files across sessions are listed. Results are paginated
|
||||
via `limit`/`offset`; `has_more` indicates whether additional pages exist.
|
||||
"""
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Treat empty-string session_id the same as omitted — an empty value
|
||||
# would otherwise silently list files across every session instead of
|
||||
# scoping to one.
|
||||
session_id = session_id or None
|
||||
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
include_all = session_id is None
|
||||
# Fetch one extra to compute has_more without a separate count query.
|
||||
files = await manager.list_files(
|
||||
limit=limit + 1,
|
||||
offset=offset,
|
||||
include_all_sessions=include_all,
|
||||
)
|
||||
has_more = len(files) > limit
|
||||
page = files[:limit]
|
||||
|
||||
return ListFilesResponse(
|
||||
files=[
|
||||
WorkspaceFileItem(
|
||||
id=f.id,
|
||||
name=f.name,
|
||||
path=f.path,
|
||||
mime_type=f.mime_type,
|
||||
size_bytes=f.size_bytes,
|
||||
metadata=f.metadata or {},
|
||||
created_at=f.created_at.isoformat(),
|
||||
)
|
||||
for f in page
|
||||
],
|
||||
offset=offset,
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
@@ -1,28 +1,48 @@
|
||||
"""Tests for workspace file upload and download routes."""
|
||||
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.workspace.routes import router
|
||||
from backend.data.workspace import Workspace, WorkspaceFile
|
||||
from backend.api.features.workspace import routes as workspace_routes
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
app.include_router(workspace_routes.router)
|
||||
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def _value_error_handler(
|
||||
request: fastapi.Request, exc: ValueError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
"""Mirror the production ValueError → 400 mapping from the REST app."""
|
||||
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
|
||||
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
|
||||
|
||||
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-1",
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
name="hello.txt",
|
||||
path="/session/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
@@ -33,201 +53,25 @@ def setup_app_auth(mock_jwt_user):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _make_workspace(user_id: str = "test-user-id") -> Workspace:
|
||||
return Workspace(
|
||||
id="ws-001",
|
||||
user_id=user_id,
|
||||
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _make_file(**overrides) -> WorkspaceFile:
|
||||
defaults = {
|
||||
"id": "file-001",
|
||||
"workspace_id": "ws-001",
|
||||
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"name": "test.txt",
|
||||
"path": "/test.txt",
|
||||
"storage_path": "local://test.txt",
|
||||
"mime_type": "text/plain",
|
||||
"size_bytes": 100,
|
||||
"checksum": None,
|
||||
"is_deleted": False,
|
||||
"deleted_at": None,
|
||||
"metadata": {},
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return WorkspaceFile(**defaults)
|
||||
|
||||
|
||||
def _make_file_mock(**overrides) -> MagicMock:
|
||||
"""Create a mock WorkspaceFile to simulate DB records with null fields."""
|
||||
defaults = {
|
||||
"id": "file-001",
|
||||
"name": "test.txt",
|
||||
"path": "/test.txt",
|
||||
"mime_type": "text/plain",
|
||||
"size_bytes": 100,
|
||||
"metadata": {},
|
||||
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
mock = MagicMock(spec=WorkspaceFile)
|
||||
for k, v in defaults.items():
|
||||
setattr(mock, k, v)
|
||||
return mock
|
||||
|
||||
|
||||
# -- list_workspace_files tests --
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_returns_all_when_no_session(mock_manager_cls, mock_get_workspace):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
files = [
|
||||
_make_file(id="f1", name="a.txt", metadata={"origin": "user-upload"}),
|
||||
_make_file(id="f2", name="b.csv", metadata={"origin": "agent-created"}),
|
||||
]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert len(data["files"]) == 2
|
||||
assert data["has_more"] is False
|
||||
assert data["offset"] == 0
|
||||
assert data["files"][0]["id"] == "f1"
|
||||
assert data["files"][0]["metadata"] == {"origin": "user-upload"}
|
||||
assert data["files"][1]["id"] == "f2"
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=201, offset=0, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_scopes_to_session_when_provided(
|
||||
mock_manager_cls, mock_get_workspace, test_user_id
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace(user_id=test_user_id)
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = []
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?session_id=sess-123")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert data["files"] == []
|
||||
assert data["has_more"] is False
|
||||
mock_manager_cls.assert_called_once_with(test_user_id, "ws-001", "sess-123")
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=201, offset=0, include_all_sessions=False
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_null_metadata_coerced_to_empty_dict(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""Route uses `f.metadata or {}` for pre-existing files with null metadata."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = [_make_file_mock(metadata=None)]
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["files"][0]["metadata"] == {}
|
||||
|
||||
|
||||
# -- upload_file metadata tests --
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
|
||||
@patch("backend.api.features.workspace.routes.scan_content_safe")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_upload_passes_user_upload_origin_metadata(
|
||||
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_total_size.return_value = 100
|
||||
written = _make_file(id="new-file", name="doc.pdf")
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.write_file.return_value = written
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.post(
|
||||
"/files/upload",
|
||||
files={"file": ("doc.pdf", b"fake-pdf-content", "application/pdf")},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
mock_instance.write_file.assert_called_once()
|
||||
call_kwargs = mock_instance.write_file.call_args
|
||||
assert call_kwargs.kwargs.get("metadata") == {"origin": "user-upload"}
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
|
||||
@patch("backend.api.features.workspace.routes.scan_content_safe")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_upload_returns_409_on_file_conflict(
|
||||
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_total_size.return_value = 100
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.write_file.side_effect = ValueError("File already exists at path")
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.post(
|
||||
"/files/upload",
|
||||
files={"file": ("dup.txt", b"content", "text/plain")},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
# -- Restored upload/download/delete security + invariant tests --
|
||||
|
||||
|
||||
def _upload(
|
||||
filename: str = "hello.txt",
|
||||
content: bytes = b"Hello, world!",
|
||||
content_type: str = "text/plain",
|
||||
):
|
||||
"""Helper to POST a file upload."""
|
||||
return client.post(
|
||||
"/files/upload?session_id=sess-1",
|
||||
files={"file": (filename, io.BytesIO(content), content_type)},
|
||||
)
|
||||
|
||||
|
||||
_MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-001",
|
||||
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
name="hello.txt",
|
||||
path="/sessions/sess-1/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
# ---- Happy path ----
|
||||
|
||||
|
||||
def test_upload_happy_path(mocker):
|
||||
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -238,7 +82,7 @@ def test_upload_happy_path(mocker):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -252,7 +96,10 @@ def test_upload_happy_path(mocker):
|
||||
assert data["size_bytes"] == 13
|
||||
|
||||
|
||||
def test_upload_exceeds_max_file_size(mocker):
|
||||
# ---- Per-file size limit ----
|
||||
|
||||
|
||||
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
"""Files larger than max_file_size_mb should be rejected with 413."""
|
||||
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
|
||||
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
|
||||
@@ -262,11 +109,15 @@ def test_upload_exceeds_max_file_size(mocker):
|
||||
assert response.status_code == 413
|
||||
|
||||
|
||||
def test_upload_storage_quota_exceeded(mocker):
|
||||
# ---- Storage quota exceeded ----
|
||||
|
||||
|
||||
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Current usage already at limit
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=500 * 1024 * 1024,
|
||||
@@ -277,22 +128,27 @@ def test_upload_storage_quota_exceeded(mocker):
|
||||
assert "Storage limit exceeded" in response.text
|
||||
|
||||
|
||||
def test_upload_post_write_quota_race(mocker):
|
||||
"""Concurrent upload tipping over limit after write should soft-delete + 413."""
|
||||
# ---- Post-write quota race (B2) ----
|
||||
|
||||
|
||||
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
"""If a concurrent upload tips the total over the limit after write,
|
||||
the file should be soft-deleted and 413 returned."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Pre-write check passes (under limit), but post-write check fails
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
side_effect=[0, 600 * 1024 * 1024],
|
||||
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -304,14 +160,17 @@ def test_upload_post_write_quota_race(mocker):
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-001")
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
|
||||
|
||||
|
||||
def test_upload_any_extension(mocker):
|
||||
# ---- Any extension accepted (no allowlist) ----
|
||||
|
||||
|
||||
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Any file extension should be accepted — ClamAV is the security layer."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -322,7 +181,7 @@ def test_upload_any_extension(mocker):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -332,13 +191,16 @@ def test_upload_any_extension(mocker):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_upload_blocked_by_virus_scan(mocker):
|
||||
# ---- Virus scan rejection ----
|
||||
|
||||
|
||||
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
"""Files flagged by ClamAV should be rejected and never written to storage."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -349,7 +211,7 @@ def test_upload_blocked_by_virus_scan(mocker):
|
||||
side_effect=VirusDetectedError("Eicar-Test-Signature"),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -357,14 +219,18 @@ def test_upload_blocked_by_virus_scan(mocker):
|
||||
|
||||
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
|
||||
assert response.status_code == 400
|
||||
assert "Virus detected" in response.text
|
||||
mock_manager.write_file.assert_not_called()
|
||||
|
||||
|
||||
def test_upload_file_without_extension(mocker):
|
||||
# ---- No file extension ----
|
||||
|
||||
|
||||
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Files without an extension should be accepted and stored as-is."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -375,7 +241,7 @@ def test_upload_file_without_extension(mocker):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -391,11 +257,14 @@ def test_upload_file_without_extension(mocker):
|
||||
assert mock_manager.write_file.call_args[0][1] == "Makefile"
|
||||
|
||||
|
||||
def test_upload_strips_path_components(mocker):
|
||||
# ---- Filename sanitization (SF5) ----
|
||||
|
||||
|
||||
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
"""Path-traversal filenames should be reduced to their basename."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -406,23 +275,28 @@ def test_upload_strips_path_components(mocker):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
# Filename with traversal
|
||||
_upload(filename="../../etc/passwd.txt")
|
||||
|
||||
# write_file should have been called with just the basename
|
||||
mock_manager.write_file.assert_called_once()
|
||||
call_args = mock_manager.write_file.call_args
|
||||
assert call_args[0][1] == "passwd.txt"
|
||||
|
||||
|
||||
def test_download_file_not_found(mocker):
|
||||
# ---- Download ----
|
||||
|
||||
|
||||
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=_make_workspace(),
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_file",
|
||||
@@ -433,11 +307,14 @@ def test_download_file_not_found(mocker):
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_delete_file_success(mocker):
|
||||
# ---- Delete ----
|
||||
|
||||
|
||||
def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting an existing file should return {"deleted": true}."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=_make_workspace(),
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
|
||||
@@ -452,11 +329,11 @@ def test_delete_file_success(mocker):
|
||||
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
|
||||
|
||||
|
||||
def test_delete_file_not_found(mocker):
|
||||
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting a non-existent file should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=_make_workspace(),
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
|
||||
@@ -470,7 +347,7 @@ def test_delete_file_not_found(mocker):
|
||||
assert "File not found" in response.text
|
||||
|
||||
|
||||
def test_delete_file_no_workspace(mocker):
|
||||
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting when user has no workspace should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
@@ -480,341 +357,3 @@ def test_delete_file_no_workspace(mocker):
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 404
|
||||
assert "Workspace not found" in response.text
|
||||
|
||||
|
||||
def test_upload_write_file_too_large_returns_413(mocker):
|
||||
"""write_file raises ValueError("File too large: …") → must map to 413."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
assert "File too large" in response.text
|
||||
|
||||
|
||||
def test_upload_write_file_conflict_returns_409(mocker):
|
||||
"""Non-'File too large' ValueErrors from write_file stay as 409."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=ValueError("File already exists at path: /sessions/x/a.txt")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.text
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_has_more_true_when_limit_exceeded(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""The limit+1 fetch trick must flip has_more=True and trim the page."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
# Backend was asked for limit+1=3, and returned exactly 3 items.
|
||||
files = [
|
||||
_make_file(id="f1", name="a.txt"),
|
||||
_make_file(id="f2", name="b.txt"),
|
||||
_make_file(id="f3", name="c.txt"),
|
||||
]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["has_more"] is True
|
||||
assert len(data["files"]) == 2
|
||||
assert data["files"][0]["id"] == "f1"
|
||||
assert data["files"][1]["id"] == "f2"
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=3, offset=0, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_has_more_false_when_exactly_page_size(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""Exactly `limit` rows means we're on the last page — has_more=False."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
files = [_make_file(id="f1", name="a.txt"), _make_file(id="f2", name="b.txt")]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["has_more"] is False
|
||||
assert len(data["files"]) == 2
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = []
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?offset=50&limit=10")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["offset"] == 50
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=11, offset=50, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
# -- _sanitize_filename_for_header tests --
|
||||
|
||||
|
||||
class TestSanitizeFilenameForHeader:
|
||||
def test_simple_ascii_attachment(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
assert _sanitize_filename_for_header("report.pdf") == (
|
||||
'attachment; filename="report.pdf"'
|
||||
)
|
||||
|
||||
def test_inline_disposition(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
assert _sanitize_filename_for_header("image.png", disposition="inline") == (
|
||||
'inline; filename="image.png"'
|
||||
)
|
||||
|
||||
def test_strips_cr_lf_null(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("a\rb\nc\x00d.txt")
|
||||
assert "\r" not in result
|
||||
assert "\n" not in result
|
||||
assert "\x00" not in result
|
||||
assert 'filename="abcd.txt"' in result
|
||||
|
||||
def test_escapes_quotes(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header('file"name.txt')
|
||||
assert 'filename="file\\"name.txt"' in result
|
||||
|
||||
def test_header_injection_blocked(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("evil.txt\r\nX-Injected: true")
|
||||
# CR/LF stripped — the remaining text is safely inside the quoted value
|
||||
assert "\r" not in result
|
||||
assert "\n" not in result
|
||||
assert result == 'attachment; filename="evil.txtX-Injected: true"'
|
||||
|
||||
def test_unicode_uses_rfc5987(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("日本語.pdf")
|
||||
assert "filename*=UTF-8''" in result
|
||||
assert "attachment" in result
|
||||
|
||||
def test_unicode_inline(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("图片.png", disposition="inline")
|
||||
assert result.startswith("inline; filename*=UTF-8''")
|
||||
|
||||
def test_empty_filename(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("")
|
||||
assert result == 'attachment; filename=""'
|
||||
|
||||
|
||||
# -- _create_streaming_response tests --
|
||||
|
||||
|
||||
class TestCreateStreamingResponse:
|
||||
def test_attachment_disposition_by_default(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
file = _make_file(name="data.bin", mime_type="application/octet-stream")
|
||||
response = _create_streaming_response(b"binary-data", file)
|
||||
assert (
|
||||
response.headers["Content-Disposition"] == 'attachment; filename="data.bin"'
|
||||
)
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
assert response.headers["Content-Length"] == "11"
|
||||
assert response.body == b"binary-data"
|
||||
|
||||
def test_inline_disposition(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
file = _make_file(name="photo.png", mime_type="image/png")
|
||||
response = _create_streaming_response(b"\x89PNG", file, inline=True)
|
||||
assert response.headers["Content-Disposition"] == 'inline; filename="photo.png"'
|
||||
assert response.headers["Content-Type"] == "image/png"
|
||||
|
||||
def test_inline_sanitizes_filename(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
file = _make_file(name='evil"\r\n.txt', mime_type="text/plain")
|
||||
response = _create_streaming_response(b"data", file, inline=True)
|
||||
assert "\r" not in response.headers["Content-Disposition"]
|
||||
assert "\n" not in response.headers["Content-Disposition"]
|
||||
assert "inline" in response.headers["Content-Disposition"]
|
||||
|
||||
def test_content_length_matches_body(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
content = b"x" * 1000
|
||||
file = _make_file(name="big.bin", mime_type="application/octet-stream")
|
||||
response = _create_streaming_response(content, file)
|
||||
assert response.headers["Content-Length"] == "1000"
|
||||
|
||||
|
||||
# -- create_file_download_response tests --
|
||||
|
||||
|
||||
class TestCreateFileDownloadResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_storage_returns_streaming_response(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.return_value = b"file contents"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(
|
||||
storage_path="local://uploads/test.txt",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"file contents"
|
||||
assert "attachment" in response.headers["Content-Disposition"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_storage_inline(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.return_value = b"\x89PNG"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(
|
||||
storage_path="local://uploads/photo.png",
|
||||
mime_type="image/png",
|
||||
name="photo.png",
|
||||
)
|
||||
response = await create_file_download_response(file, inline=True)
|
||||
assert "inline" in response.headers["Content-Disposition"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_redirect(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.return_value = (
|
||||
"https://storage.googleapis.com/signed-url"
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.pdf")
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers["location"] == "https://storage.googleapis.com/signed-url"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_api_fallback_streams_directly(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.return_value = "/api/fallback"
|
||||
mock_storage.retrieve.return_value = b"fallback content"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.txt")
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"fallback content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_signed_url_failure_falls_back_to_streaming(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
|
||||
mock_storage.retrieve.return_value = b"streamed"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.txt")
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"streamed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_total_failure_raises(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
|
||||
mock_storage.retrieve.side_effect = RuntimeError("Also failed")
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.txt")
|
||||
with pytest.raises(RuntimeError, match="Also failed"):
|
||||
await create_file_download_response(file)
|
||||
|
||||
@@ -17,10 +17,7 @@ from fastapi.routing import APIRoute
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.diagnostics_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.platform_cost_routes
|
||||
import backend.api.features.admin.rate_limit_admin_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
@@ -32,7 +29,6 @@ import backend.api.features.library.routes
|
||||
import backend.api.features.mcp.routes as mcp_routes
|
||||
import backend.api.features.oauth
|
||||
import backend.api.features.otto.routes
|
||||
import backend.api.features.platform_linking.routes
|
||||
import backend.api.features.postmark.postmark
|
||||
import backend.api.features.store.model
|
||||
import backend.api.features.store.routes
|
||||
@@ -121,11 +117,6 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Register managed credential providers (e.g. AgentMail)
|
||||
from backend.integrations.managed_providers import register_all
|
||||
|
||||
register_all()
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
@@ -219,22 +210,13 @@ instrument_fastapi(
|
||||
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: fastapi.Request, exc: Exception):
|
||||
if log_error:
|
||||
if status_code >= 500:
|
||||
logger.exception(
|
||||
"%s %s failed. Investigate and resolve the underlying issue: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"%s %s failed with %d: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
status_code,
|
||||
exc,
|
||||
)
|
||||
logger.exception(
|
||||
"%s %s failed. Investigate and resolve the underlying issue: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
|
||||
hint = (
|
||||
"Adjust the request and retry."
|
||||
@@ -284,10 +266,12 @@ async def validation_error_handler(
|
||||
|
||||
|
||||
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
||||
app.add_exception_handler(FolderAlreadyExistsError, handle_internal_http_error(409))
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403))
|
||||
app.add_exception_handler(
|
||||
FolderAlreadyExistsError, handle_internal_http_error(409, False)
|
||||
)
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
||||
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||
@@ -322,26 +306,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/credits",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.diagnostics_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.execution_analytics_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.rate_limit_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/copilot",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.platform_cost_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/admin",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
@@ -379,11 +348,6 @@ app.include_router(
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.platform_linking.routes.router,
|
||||
tags=["platform-linking"],
|
||||
prefix="/api/platform-linking",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
|
||||
@@ -557,11 +521,8 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
user_id: str,
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
):
|
||||
from backend.api.features.integrations.router import (
|
||||
create_credentials,
|
||||
get_credential,
|
||||
)
|
||||
) -> Credentials:
|
||||
from .features.integrations.router import create_credentials, get_credential
|
||||
|
||||
try:
|
||||
return await create_credentials(
|
||||
|
||||
@@ -42,13 +42,11 @@ def main(**kwargs):
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
from backend.executor import ExecutionManager, Scheduler
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.platform_linking.manager import PlatformLinkingManager
|
||||
|
||||
run_processes(
|
||||
DatabaseManager().set_log_level("warning"),
|
||||
Scheduler(),
|
||||
NotificationManager(),
|
||||
PlatformLinkingManager(),
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
ExecutionManager(),
|
||||
|
||||
@@ -25,7 +25,6 @@ from backend.data.model import (
|
||||
Credentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
@@ -44,7 +43,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import ContributorDetails
|
||||
from backend.data.model import ContributorDetails, NodeExecutionStats
|
||||
|
||||
from ..data.graph import Link
|
||||
|
||||
@@ -96,64 +95,27 @@ class BlockCategory(Enum):
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
# RUN : cost_amount credits per run.
|
||||
# BYTE : cost_amount credits per byte of input data.
|
||||
# SECOND : cost_amount credits per cost_divisor walltime seconds.
|
||||
# ITEMS : cost_amount credits per cost_divisor items (from stats).
|
||||
# COST_USD : cost_amount credits per USD of stats.provider_cost.
|
||||
# TOKENS : per-(model, provider) rate table lookup; see TOKEN_COST.
|
||||
RUN = "run"
|
||||
BYTE = "byte"
|
||||
SECOND = "second"
|
||||
ITEMS = "items"
|
||||
COST_USD = "cost_usd"
|
||||
TOKENS = "tokens"
|
||||
|
||||
@property
|
||||
def is_dynamic(self) -> bool:
|
||||
"""Real charge is computed post-flight from stats.
|
||||
|
||||
Dynamic types (SECOND/ITEMS/COST_USD/TOKENS) return 0 pre-flight and
|
||||
settle against stats via charge_reconciled_usage once the block runs.
|
||||
"""
|
||||
return self in _DYNAMIC_COST_TYPES
|
||||
|
||||
|
||||
_DYNAMIC_COST_TYPES: frozenset[BlockCostType] = frozenset(
|
||||
{
|
||||
BlockCostType.SECOND,
|
||||
BlockCostType.ITEMS,
|
||||
BlockCostType.COST_USD,
|
||||
BlockCostType.TOKENS,
|
||||
}
|
||||
)
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
# cost_divisor: interpret cost_amount as "credits per cost_divisor units".
|
||||
# Only meaningful for SECOND / ITEMS. TOKENS routes through TOKEN_COST
|
||||
# rate tables (per-model input/output/cache pricing) and ignores
|
||||
# cost_divisor entirely. Defaults to 1 so existing RUN/BYTE entries stay
|
||||
# point-wise. Example: cost_amount=1, cost_divisor=10 under SECOND means
|
||||
# "1 credit per 10 seconds of walltime".
|
||||
cost_divisor: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
cost_divisor: int = 1,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
cost_divisor=max(1, cost_divisor),
|
||||
**data,
|
||||
)
|
||||
|
||||
@@ -205,31 +167,9 @@ class BlockSchema(BaseModel):
|
||||
return cls.cached_jsonschema
|
||||
|
||||
@classmethod
|
||||
def validate_data(
|
||||
cls,
|
||||
data: BlockInput,
|
||||
exclude_fields: set[str] | None = None,
|
||||
) -> str | None:
|
||||
schema = cls.jsonschema()
|
||||
if exclude_fields:
|
||||
# Drop the excluded fields from both the properties and the
|
||||
# ``required`` list so jsonschema doesn't flag them as missing.
|
||||
# Used by the dry-run path to skip credentials validation while
|
||||
# still validating the remaining block inputs.
|
||||
schema = {
|
||||
**schema,
|
||||
"properties": {
|
||||
k: v
|
||||
for k, v in schema.get("properties", {}).items()
|
||||
if k not in exclude_fields
|
||||
},
|
||||
"required": [
|
||||
r for r in schema.get("required", []) if r not in exclude_fields
|
||||
],
|
||||
}
|
||||
data = {k: v for k, v in data.items() if k not in exclude_fields}
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(
|
||||
schema=schema,
|
||||
schema=cls.jsonschema(),
|
||||
data={k: v for k, v in data.items() if v is not None},
|
||||
)
|
||||
|
||||
@@ -370,8 +310,6 @@ class BlockSchema(BaseModel):
|
||||
"credentials_provider": [config.get("provider", "google")],
|
||||
"credentials_types": [config.get("type", "oauth2")],
|
||||
"credentials_scopes": config.get("scopes"),
|
||||
"is_auto_credential": True,
|
||||
"input_field_name": info["field_name"],
|
||||
}
|
||||
result[kwarg_name] = CredentialsFieldInfo.model_validate(
|
||||
auto_schema, by_alias=True
|
||||
@@ -517,6 +455,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
disabled: If the block is disabled, it will not be available for execution.
|
||||
static_output: Whether the output links of the block are static by default.
|
||||
"""
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
self.id = id
|
||||
self.input_schema = input_schema
|
||||
self.output_schema = output_schema
|
||||
@@ -534,7 +474,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.is_sensitive_action = is_sensitive_action
|
||||
# Read from ClassVar set by initialize_blocks()
|
||||
self.optimized_description: str | None = type(self)._optimized_description
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -614,7 +554,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
return data
|
||||
raise ValueError(f"{self.name} did not produce any output for {output}")
|
||||
|
||||
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
|
||||
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
|
||||
self.execution_stats += stats
|
||||
return self.execution_stats
|
||||
|
||||
@@ -758,90 +698,13 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
if should_pause:
|
||||
return
|
||||
|
||||
# Validate the input data (original or reviewer-modified) once.
|
||||
# In dry-run mode, credential fields may contain sentinel None values
|
||||
# that would fail JSON schema required checks. We still validate the
|
||||
# non-credential fields so blocks that execute for real during dry-run
|
||||
# (e.g. AgentExecutorBlock) get proper input validation.
|
||||
is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
|
||||
if is_dry_run:
|
||||
# Credential fields may be absent (LLM-built agents often skip
|
||||
# wiring them) or nullified earlier in the pipeline. Validate
|
||||
# the non-credential inputs against a schema with those fields
|
||||
# excluded — stripping only the data while keeping them in the
|
||||
# ``required`` list would falsely report ``'credentials' is a
|
||||
# required property``.
|
||||
cred_field_names = set(self.input_schema.get_credentials_fields().keys())
|
||||
if error := self.input_schema.validate_data(
|
||||
input_data, exclude_fields=cred_field_names
|
||||
):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
else:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Ensure auto-credential kwargs are present before we hand off to
|
||||
# run(). A missing auto-credential means the upstream field (e.g.
|
||||
# a Google Drive picker) didn't embed a _credentials_id, or the
|
||||
# executor couldn't resolve it. Without this guard, run() would
|
||||
# crash with a TypeError (missing required kwarg) or an opaque
|
||||
# AttributeError deep inside the provider SDK.
|
||||
#
|
||||
# Only raise when the field is ALSO not populated in input_data.
|
||||
# ``_acquire_auto_credentials`` intentionally skips setting the
|
||||
# kwarg in two legitimate cases — ``_credentials_id`` is ``None``
|
||||
# (chained from upstream) or the field is missing from
|
||||
# ``input_data`` at prep time (connected from upstream block).
|
||||
# In both cases the upstream block is expected to populate the
|
||||
# field value by execute time; raising here would break the
|
||||
# documented ``AgentGoogleDriveFileInputBlock`` chaining pattern.
|
||||
# Dry-run skips because the executor intentionally runs blocks
|
||||
# without resolved creds for schema validation.
|
||||
if not is_dry_run:
|
||||
for (
|
||||
kwarg_name,
|
||||
info,
|
||||
) in self.input_schema.get_auto_credentials_fields().items():
|
||||
kwargs.setdefault(kwarg_name, None)
|
||||
if kwargs[kwarg_name] is not None:
|
||||
continue
|
||||
# Upstream-chained pattern: the field was populated by a
|
||||
# prior node (e.g. AgentGoogleDriveFileInputBlock) whose
|
||||
# output carries a resolved ``_credentials_id``.
|
||||
# ``_acquire_auto_credentials`` deliberately doesn't set
|
||||
# the kwarg in that case because the value isn't available
|
||||
# at prep time; the executor fills it in before we reach
|
||||
# ``_execute``. Trust it if the ``_credentials_id`` KEY
|
||||
# is present — its value may be explicitly ``None`` in
|
||||
# the chained case (see sentry thread
|
||||
# PRRT_kwDOJKSTjM58sJfA). Checking truthiness here would
|
||||
# falsely preempt run() for every valid chained graph
|
||||
# that ships ``_credentials_id=None`` in the picker
|
||||
# object. Mirror ``_acquire_auto_credentials``'s own
|
||||
# skip rule, which treats ``cred_id is None`` as a
|
||||
# chained-skip signal.
|
||||
field_name = info["field_name"]
|
||||
field_value = input_data.get(field_name)
|
||||
if isinstance(field_value, dict) and "_credentials_id" in field_value:
|
||||
continue
|
||||
raise BlockExecutionError(
|
||||
message=(
|
||||
f"Missing credentials for '{kwarg_name}'. "
|
||||
"Select a file via the picker (which carries "
|
||||
"its credentials), or connect credentials for "
|
||||
"this block."
|
||||
),
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
# Validate the input data (original or reviewer-modified) once
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Use the validated input data
|
||||
async for output_name, output_data in self.run(
|
||||
|
||||
@@ -49,17 +49,11 @@ class AgentExecutorBlock(Block):
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
required_fields = cls.get_input_schema(data).get("required", [])
|
||||
# Check against the nested `inputs` dict, not the top-level node
|
||||
# data — required fields like "topic" live inside data["inputs"],
|
||||
# not at data["topic"].
|
||||
provided = data.get("inputs", {})
|
||||
return set(required_fields) - set(provided)
|
||||
return set(required_fields) - set(data)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return validate_with_jsonschema(
|
||||
cls.get_input_schema(data), data.get("inputs", {})
|
||||
)
|
||||
return validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Use BlockSchema to avoid automatic error field that could clash with graph outputs
|
||||
@@ -94,7 +88,6 @@ class AgentExecutorBlock(Block):
|
||||
execution_context=execution_context.model_copy(
|
||||
update={"parent_execution_id": graph_exec_id},
|
||||
),
|
||||
dry_run=execution_context.dry_run,
|
||||
)
|
||||
|
||||
logger = execution_utils.LogMetadata(
|
||||
@@ -156,25 +149,17 @@ class AgentExecutorBlock(Block):
|
||||
ExecutionStatus.TERMINATED,
|
||||
ExecutionStatus.FAILED,
|
||||
]:
|
||||
logger.info(
|
||||
f"Execution {log_id} skipping event {event.event_type} status={event.status} "
|
||||
f"node={getattr(event, 'node_exec_id', '?')}"
|
||||
logger.debug(
|
||||
f"Execution {log_id} received event {event.event_type} with status {event.status}"
|
||||
)
|
||||
continue
|
||||
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
|
||||
# we can stop listening for further events.
|
||||
logger.info(
|
||||
f"Execution {log_id} graph completed with status {event.status}, "
|
||||
f"yielded {len(yielded_node_exec_ids)} outputs"
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
# Sub-graph already debited each of its own nodes; we
|
||||
# roll up its total so graph_stats.cost reflects the
|
||||
# full sub-graph spend.
|
||||
reconciled_cost_delta=(event.stats.cost if event.stats else 0),
|
||||
extra_cost=event.stats.cost if event.stats else 0,
|
||||
extra_steps=event.stats.node_exec_count if event.stats else 0,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4,16 +4,11 @@ Shared configuration for all AgentMail blocks.
|
||||
|
||||
from agentmail import AsyncAgentMail
|
||||
|
||||
from backend.sdk import APIKeyCredentials, BlockCostType, ProviderBuilder, SecretStr
|
||||
from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
|
||||
|
||||
# AgentMail is in beta with no published paid tier yet, but ~37 blocks
|
||||
# without any BLOCK_COSTS entry means they currently execute wallet-free.
|
||||
# 1 cr/call is a conservative interim floor so no AgentMail work leaks
|
||||
# past billing. Revisit once AgentMail publishes usage-based pricing.
|
||||
agent_mail = (
|
||||
ProviderBuilder("agent_mail")
|
||||
.with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -20,33 +19,6 @@ from backend.blocks.llm import (
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||
|
||||
# Minimum max_output_tokens accepted by OpenAI-compatible APIs.
|
||||
# A true/false answer fits comfortably within this budget.
|
||||
MIN_LLM_OUTPUT_TOKENS = 16
|
||||
|
||||
|
||||
def _parse_boolean_response(response_text: str) -> tuple[bool, str | None]:
|
||||
"""Parse an LLM response into a boolean result.
|
||||
|
||||
Returns a ``(result, error)`` tuple. *error* is ``None`` when the
|
||||
response is unambiguous; otherwise it contains a diagnostic message
|
||||
and *result* defaults to ``False``.
|
||||
"""
|
||||
text = response_text.strip().lower()
|
||||
if text == "true":
|
||||
return True, None
|
||||
if text == "false":
|
||||
return False, None
|
||||
|
||||
# Fuzzy match – use word boundaries to avoid false positives like "untrue".
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", text))
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
return True, None
|
||||
if tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
return False, None
|
||||
|
||||
return False, f"Unclear AI response: '{response_text}'"
|
||||
|
||||
|
||||
class AIConditionBlock(AIBlockBase):
|
||||
"""
|
||||
@@ -190,29 +162,54 @@ class AIConditionBlock(AIBlockBase):
|
||||
]
|
||||
|
||||
# Call the LLM
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=MIN_LLM_OUTPUT_TOKENS,
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
result, error = _parse_boolean_response(response.response)
|
||||
if error:
|
||||
yield "error", error
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
cache_read_token_count=response.cache_read_tokens,
|
||||
cache_creation_token_count=response.cache_creation_tokens,
|
||||
provider_cost=response.provider_cost,
|
||||
try:
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=10, # We only expect a true/false response
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
# Extract the boolean result from the response
|
||||
response_text = response.response.strip().lower()
|
||||
if response_text == "true":
|
||||
result = True
|
||||
elif response_text == "false":
|
||||
result = False
|
||||
else:
|
||||
# If the response is not clear, try to interpret it using word boundaries
|
||||
import re
|
||||
|
||||
# Use word boundaries to avoid false positives like 'untrue' or '10'
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
|
||||
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
result = True
|
||||
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
result = False
|
||||
else:
|
||||
# Unclear or conflicting response - default to False and yield error
|
||||
result = False
|
||||
yield "error", f"Unclear AI response: '{response.response}'"
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
except Exception as e:
|
||||
# In case of any error, default to False to be safe
|
||||
result = False
|
||||
# Log the error but don't fail the block execution
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"AI condition evaluation failed: {str(e)}")
|
||||
yield "error", f"AI evaluation failed: {str(e)}"
|
||||
|
||||
# Yield results
|
||||
yield "result", result
|
||||
|
||||
@@ -1,188 +0,0 @@
|
||||
"""Tests for AIConditionBlock – regression coverage for max_tokens and error propagation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.ai_condition import (
|
||||
MIN_LLM_OUTPUT_TOKENS,
|
||||
AIConditionBlock,
|
||||
_parse_boolean_response,
|
||||
)
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AICredentials,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
_TEST_AI_CREDENTIALS = cast(AICredentials, TEST_CREDENTIALS_INPUT)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to collect all yields from the async generator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _collect_outputs(block: AIConditionBlock, input_data, credentials):
|
||||
outputs: dict[str, object] = {}
|
||||
async for name, value in block.run(input_data, credentials=credentials):
|
||||
outputs[name] = value
|
||||
return outputs
|
||||
|
||||
|
||||
def _make_input(**overrides) -> AIConditionBlock.Input:
|
||||
defaults: dict = {
|
||||
"input_value": "hello@example.com",
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "yes!",
|
||||
"no_value": "no!",
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return AIConditionBlock.Input(**defaults)
|
||||
|
||||
|
||||
def _mock_llm_response(
|
||||
response_text: str,
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
provider_cost: float | None = None,
|
||||
) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response=response_text,
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
reasoning=None,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
provider_cost=provider_cost,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_boolean_response unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseBooleanResponse:
|
||||
def test_true_exact(self):
|
||||
assert _parse_boolean_response("true") == (True, None)
|
||||
|
||||
def test_false_exact(self):
|
||||
assert _parse_boolean_response("false") == (False, None)
|
||||
|
||||
def test_true_with_whitespace(self):
|
||||
assert _parse_boolean_response(" True ") == (True, None)
|
||||
|
||||
def test_yes_fuzzy(self):
|
||||
assert _parse_boolean_response("Yes") == (True, None)
|
||||
|
||||
def test_no_fuzzy(self):
|
||||
assert _parse_boolean_response("no") == (False, None)
|
||||
|
||||
def test_one_fuzzy(self):
|
||||
assert _parse_boolean_response("1") == (True, None)
|
||||
|
||||
def test_zero_fuzzy(self):
|
||||
assert _parse_boolean_response("0") == (False, None)
|
||||
|
||||
def test_unclear_response(self):
|
||||
result, error = _parse_boolean_response("I'm not sure")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
assert "Unclear" in error
|
||||
|
||||
def test_conflicting_tokens(self):
|
||||
result, error = _parse_boolean_response("true and false")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: max_tokens is set to MIN_LLM_OUTPUT_TOKENS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaxTokensRegression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_receives_min_output_tokens(self):
|
||||
"""max_tokens must be MIN_LLM_OUTPUT_TOKENS (16) – the previous value
|
||||
of 1 was too low and caused OpenAI to reject the request."""
|
||||
block = AIConditionBlock()
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def spy_llm_call(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return _mock_llm_response("true")
|
||||
|
||||
block.llm_call = spy_llm_call # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
|
||||
assert captured_kwargs["max_tokens"] == MIN_LLM_OUTPUT_TOKENS
|
||||
assert captured_kwargs["max_tokens"] == 16
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: exceptions from llm_call must propagate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExceptionPropagation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_exception_propagates(self):
|
||||
"""If llm_call raises, the exception must NOT be swallowed.
|
||||
Previously the block caught all exceptions and silently returned
|
||||
result=False."""
|
||||
block = AIConditionBlock()
|
||||
|
||||
async def boom(**kwargs):
|
||||
raise RuntimeError("LLM provider error")
|
||||
|
||||
block.llm_call = boom # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
with pytest.raises(RuntimeError, match="LLM provider error"):
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: cache tokens and provider_cost must be propagated to stats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCacheTokenPropagation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_tokens_propagated_to_stats(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""cache_read_tokens and cache_creation_tokens must be forwarded to
|
||||
NodeExecutionStats so that usage dashboards count cached tokens."""
|
||||
block = AIConditionBlock()
|
||||
|
||||
async def spy_llm(**kwargs):
|
||||
return _mock_llm_response(
|
||||
"true",
|
||||
cache_read_tokens=7,
|
||||
cache_creation_tokens=3,
|
||||
provider_cost=0.0012,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(block, "llm_call", spy_llm)
|
||||
|
||||
input_data = _make_input()
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
|
||||
assert block.execution_stats.cache_read_token_count == 7
|
||||
assert block.execution_stats.cache_creation_token_count == 3
|
||||
assert block.execution_stats.provider_cost == 0.0012
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.apollo.models import (
|
||||
PrimaryPhone,
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
@@ -218,11 +218,6 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
) -> BlockOutput:
|
||||
query = SearchOrganizationsRequest(**input_data.model_dump())
|
||||
organizations = await self.search_organizations(query, credentials)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(organizations)), provider_cost_type="items"
|
||||
)
|
||||
)
|
||||
for organization in organizations:
|
||||
yield "organization", organization
|
||||
yield "organizations", organizations
|
||||
|
||||
@@ -21,7 +21,7 @@ from backend.blocks.apollo.models import (
|
||||
SearchPeopleRequest,
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
@@ -366,9 +366,4 @@ class SearchPeopleBlock(Block):
|
||||
*(enrich_or_fallback(person) for person in people)
|
||||
)
|
||||
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(people)), provider_cost_type="items"
|
||||
)
|
||||
)
|
||||
yield "people", people
|
||||
|
||||
@@ -4,7 +4,6 @@ import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import TypedDict # Needed for Python <3.12 compatibility
|
||||
@@ -16,14 +15,7 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.copilot.permissions import (
|
||||
CopilotPermissions,
|
||||
ToolName,
|
||||
all_known_tool_names,
|
||||
validate_block_identifiers,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
@@ -33,37 +25,6 @@ logger = logging.getLogger(__name__)
|
||||
# Block ID shared between autopilot.py and copilot prompting.py.
|
||||
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
|
||||
# Identifiers used when registering an AutoPilotBlock turn with the
|
||||
# stream registry — distinguishes block-originated turns from sub-session
|
||||
# or HTTP SSE turns in logs / observability.
|
||||
_AUTOPILOT_TOOL_CALL_ID = "autopilot_block"
|
||||
_AUTOPILOT_TOOL_NAME = "autopilot_block"
|
||||
|
||||
# Ceiling on how long AutoPilotBlock.execute_copilot will wait for the
|
||||
# enqueued turn's terminal event. Graph blocks run synchronously from
|
||||
# the caller's perspective so we wait effectively as long as needed; 6h
|
||||
# matches the previous abandoned-task cap and is much longer than any
|
||||
# legitimate AutoPilot turn.
|
||||
_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS = 6 * 60 * 60 # 6 hours
|
||||
|
||||
|
||||
class SubAgentRecursionError(BlockExecutionError):
|
||||
"""Raised when the AutoPilot sub-agent nesting depth limit is exceeded.
|
||||
|
||||
Inherits :class:`BlockExecutionError` — this is a known, handled
|
||||
runtime failure at the block level (caller nested AutoPilotBlocks
|
||||
beyond the configured limit). Surfaces with the block_name /
|
||||
block_id the block framework expects, instead of being wrapped in
|
||||
``BlockUnknownError``.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(
|
||||
message=message,
|
||||
block_name="AutoPilotBlock",
|
||||
block_id=AUTOPILOT_BLOCK_ID,
|
||||
)
|
||||
|
||||
|
||||
class ToolCallEntry(TypedDict):
|
||||
"""A single tool invocation record from an autopilot execution."""
|
||||
@@ -135,65 +96,6 @@ class AutoPilotBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools: list[ToolName] = SchemaField(
|
||||
description=(
|
||||
"Tool names to filter. Works with tools_exclude to form an "
|
||||
"allow-list or deny-list. "
|
||||
"Leave empty to apply no tool filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'tools' list is interpreted. "
|
||||
"True (default): 'tools' is a deny-list — listed tools are blocked, "
|
||||
"all others are allowed. An empty 'tools' list means allow everything. "
|
||||
"False: 'tools' is an allow-list — only listed tools are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks: list[str] = SchemaField(
|
||||
description=(
|
||||
"Block identifiers to filter when the copilot uses run_block. "
|
||||
"Each entry can be: a block name (e.g. 'HTTP Request'), "
|
||||
"a full block UUID, or the first 8 hex characters of the UUID "
|
||||
"(e.g. 'c069dc6b'). Works with blocks_exclude. "
|
||||
"Leave empty to apply no block filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'blocks' list is interpreted. "
|
||||
"True (default): 'blocks' is a deny-list — listed blocks are blocked, "
|
||||
"all others are allowed. An empty 'blocks' list means allow everything. "
|
||||
"False: 'blocks' is an allow-list — only listed blocks are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
dry_run: bool = SchemaField(
|
||||
description=(
|
||||
"When enabled, run_block and run_agent tool calls in this "
|
||||
"autopilot session are forced to use dry-run simulation mode. "
|
||||
"No real API calls, side effects, or credits are consumed "
|
||||
"by those tools. Useful for testing agent wiring and "
|
||||
"previewing outputs. "
|
||||
"Only applies when creating a new session (session_id is empty). "
|
||||
"When reusing an existing session_id, the session's original "
|
||||
"dry_run setting is preserved."
|
||||
),
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# timeout_seconds removed: the SDK manages its own heartbeat-based
|
||||
# timeouts internally; wrapping with asyncio.timeout corrupts the
|
||||
# SDK's internal stream (see service.py CRITICAL comment).
|
||||
@@ -280,11 +182,11 @@ class AutoPilotBlock(Block):
|
||||
},
|
||||
)
|
||||
|
||||
async def create_session(self, user_id: str, *, dry_run: bool) -> str:
|
||||
async def create_session(self, user_id: str) -> str:
|
||||
"""Create a new chat session and return its ID (mockable for tests)."""
|
||||
from backend.copilot.model import create_chat_session # avoid circular import
|
||||
from backend.copilot.model import create_chat_session
|
||||
|
||||
session = await create_chat_session(user_id, dry_run=dry_run)
|
||||
session = await create_chat_session(user_id)
|
||||
return session.session_id
|
||||
|
||||
async def execute_copilot(
|
||||
@@ -294,17 +196,12 @@ class AutoPilotBlock(Block):
|
||||
session_id: str,
|
||||
max_recursion_depth: int,
|
||||
user_id: str,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
|
||||
"""Invoke the copilot on the copilot_executor queue and aggregate the
|
||||
result.
|
||||
"""Invoke the copilot and collect all stream results.
|
||||
|
||||
Delegates to :func:`run_copilot_turn_via_queue` — the shared
|
||||
primitive used by ``run_sub_session`` too — which creates the
|
||||
stream_registry meta record, enqueues the job, and waits on the
|
||||
Redis stream for the terminal event. Any available
|
||||
copilot_executor worker picks up the job, so this call survives
|
||||
the graph-executor worker dying mid-turn (RabbitMQ redelivers).
|
||||
Delegates to :func:`collect_copilot_response` — the shared helper that
|
||||
consumes ``stream_chat_completion_sdk`` without wrapping it in an
|
||||
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
|
||||
|
||||
Args:
|
||||
prompt: The user task/instruction.
|
||||
@@ -312,54 +209,25 @@ class AutoPilotBlock(Block):
|
||||
session_id: Chat session to use.
|
||||
max_recursion_depth: Maximum allowed recursion nesting.
|
||||
user_id: Authenticated user ID.
|
||||
permissions: Optional capability filter restricting tools/blocks.
|
||||
|
||||
Returns:
|
||||
A tuple of (response_text, tool_calls, history_json, session_id, usage).
|
||||
"""
|
||||
from backend.copilot.sdk.session_waiter import (
|
||||
run_copilot_turn_via_queue, # avoid circular import
|
||||
)
|
||||
from backend.copilot.sdk.collect import collect_copilot_response
|
||||
|
||||
tokens = _check_recursion(max_recursion_depth)
|
||||
perm_token = None
|
||||
try:
|
||||
effective_permissions, perm_token = _merge_inherited_permissions(
|
||||
permissions
|
||||
)
|
||||
effective_prompt = prompt
|
||||
if system_context:
|
||||
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
|
||||
|
||||
outcome, result = await run_copilot_turn_via_queue(
|
||||
result = await collect_copilot_response(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=effective_prompt,
|
||||
# Graph block execution is synchronous from the caller's
|
||||
# perspective — wait effectively as long as needed. The
|
||||
# SDK enforces its own idle-based timeout inside the
|
||||
# stream_registry pipeline.
|
||||
timeout=_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS,
|
||||
permissions=effective_permissions,
|
||||
tool_call_id=_AUTOPILOT_TOOL_CALL_ID,
|
||||
tool_name=_AUTOPILOT_TOOL_NAME,
|
||||
user_id=user_id,
|
||||
)
|
||||
if outcome == "failed":
|
||||
raise RuntimeError(
|
||||
"AutoPilot turn failed — see the session's transcript"
|
||||
)
|
||||
if outcome == "running":
|
||||
raise RuntimeError(
|
||||
"AutoPilot turn did not complete within "
|
||||
f"{_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS}s — session "
|
||||
f"{session_id}"
|
||||
)
|
||||
|
||||
# Build a lightweight conversation summary from the aggregated data.
|
||||
# When ``result.queued`` is True the prompt rode on an already-
|
||||
# in-flight turn (``run_copilot_turn_via_queue`` queued it and
|
||||
# waited on the existing turn's stream); the aggregated result
|
||||
# is still valid, so the same rendering path applies.
|
||||
# Build a lightweight conversation summary from streamed data.
|
||||
turn_messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": effective_prompt},
|
||||
]
|
||||
@@ -368,7 +236,7 @@ class AutoPilotBlock(Block):
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": result.response_text,
|
||||
"tool_calls": [tc.model_dump() for tc in result.tool_calls],
|
||||
"tool_calls": result.tool_calls,
|
||||
}
|
||||
)
|
||||
else:
|
||||
@@ -379,11 +247,11 @@ class AutoPilotBlock(Block):
|
||||
|
||||
tool_calls: list[ToolCallEntry] = [
|
||||
{
|
||||
"tool_call_id": tc.tool_call_id,
|
||||
"tool_name": tc.tool_name,
|
||||
"input": tc.input,
|
||||
"output": tc.output,
|
||||
"success": tc.success,
|
||||
"tool_call_id": tc["tool_call_id"],
|
||||
"tool_name": tc["tool_name"],
|
||||
"input": tc["input"],
|
||||
"output": tc["output"],
|
||||
"success": tc["success"],
|
||||
}
|
||||
for tc in result.tool_calls
|
||||
]
|
||||
@@ -403,8 +271,6 @@ class AutoPilotBlock(Block):
|
||||
)
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
if perm_token is not None:
|
||||
_inherited_permissions.reset(perm_token)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -429,21 +295,11 @@ class AutoPilotBlock(Block):
|
||||
yield "error", "max_recursion_depth must be at least 1."
|
||||
return
|
||||
|
||||
# Validate and build permissions eagerly — fail before creating a session.
|
||||
permissions = await _build_and_validate_permissions(input_data)
|
||||
if isinstance(permissions, str):
|
||||
# Validation error returned as a string message.
|
||||
yield "error", permissions
|
||||
return
|
||||
|
||||
# Create session eagerly so the user always gets the session_id,
|
||||
# even if the downstream stream fails (avoids orphaned sessions).
|
||||
sid = input_data.session_id
|
||||
if not sid:
|
||||
sid = await self.create_session(
|
||||
execution_context.user_id,
|
||||
dry_run=input_data.dry_run or execution_context.dry_run,
|
||||
)
|
||||
sid = await self.create_session(execution_context.user_id)
|
||||
|
||||
# NOTE: No asyncio.timeout() here — the SDK manages its own
|
||||
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout
|
||||
@@ -456,7 +312,6 @@ class AutoPilotBlock(Block):
|
||||
session_id=sid,
|
||||
max_recursion_depth=input_data.max_recursion_depth,
|
||||
user_id=execution_context.user_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
@@ -468,41 +323,8 @@ class AutoPilotBlock(Block):
|
||||
yield "session_id", sid
|
||||
yield "error", "AutoPilot execution was cancelled."
|
||||
raise
|
||||
except SubAgentRecursionError as exc:
|
||||
# Deliberate block — re-enqueueing would immediately hit the limit
|
||||
# again, so skip recovery and just surface the error.
|
||||
yield "session_id", sid
|
||||
yield "error", str(exc)
|
||||
except Exception as exc:
|
||||
yield "session_id", sid
|
||||
# Recovery enqueue must happen BEFORE yielding "error": the block
|
||||
# framework (_base.execute) raises BlockExecutionError immediately
|
||||
# when it sees ("error", ...) and stops consuming the generator,
|
||||
# so any code after that yield is dead code in production.
|
||||
effective_prompt = input_data.prompt
|
||||
if input_data.system_context:
|
||||
effective_prompt = (
|
||||
f"[System Context: {input_data.system_context}]\n\n"
|
||||
f"{input_data.prompt}"
|
||||
)
|
||||
try:
|
||||
await _enqueue_for_recovery(
|
||||
sid,
|
||||
execution_context.user_id,
|
||||
effective_prompt,
|
||||
input_data.dry_run or execution_context.dry_run,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
# Task cancelled during recovery — still yield the error
|
||||
# so the session_id + error pair is visible before re-raising.
|
||||
yield "error", str(exc)
|
||||
raise
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"AutoPilot session %s: recovery enqueue raised unexpectedly",
|
||||
sid[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
yield "error", str(exc)
|
||||
|
||||
|
||||
@@ -530,13 +352,13 @@ def _check_recursion(
|
||||
when the caller exits to restore the previous depth.
|
||||
|
||||
Raises:
|
||||
SubAgentRecursionError: If the current depth already meets or exceeds the limit.
|
||||
RuntimeError: If the current depth already meets or exceeds the limit.
|
||||
"""
|
||||
current = _autopilot_recursion_depth.get()
|
||||
inherited = _autopilot_recursion_limit.get()
|
||||
limit = max_depth if inherited is None else min(inherited, max_depth)
|
||||
if current >= limit:
|
||||
raise SubAgentRecursionError(
|
||||
raise RuntimeError(
|
||||
f"AutoPilot recursion depth limit reached ({limit}). "
|
||||
"The autopilot has called itself too many times."
|
||||
)
|
||||
@@ -552,126 +374,3 @@ def _reset_recursion(
|
||||
"""Restore recursion depth and limit to their previous values."""
|
||||
_autopilot_recursion_depth.reset(tokens[0])
|
||||
_autopilot_recursion_limit.reset(tokens[1])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Permission helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Inherited permissions from a parent AutoPilotBlock execution.
|
||||
# This acts as a ceiling: child executions can only be more restrictive.
|
||||
_inherited_permissions: contextvars.ContextVar["CopilotPermissions | None"] = (
|
||||
contextvars.ContextVar("_inherited_permissions", default=None)
|
||||
)
|
||||
|
||||
|
||||
async def _build_and_validate_permissions(
|
||||
input_data: "AutoPilotBlock.Input",
|
||||
) -> "CopilotPermissions | str":
|
||||
"""Build a :class:`CopilotPermissions` from block input and validate it.
|
||||
|
||||
Returns a :class:`CopilotPermissions` on success or a human-readable
|
||||
error string if validation fails.
|
||||
"""
|
||||
# Tool names are validated by Pydantic via the ToolName Literal type
|
||||
# at model construction time — no runtime check needed here.
|
||||
# Validate block identifiers against live block registry.
|
||||
if input_data.blocks:
|
||||
invalid_blocks = await validate_block_identifiers(input_data.blocks)
|
||||
if invalid_blocks:
|
||||
return (
|
||||
f"Unknown block identifier(s) in 'blocks': {invalid_blocks}. "
|
||||
"Use find_block to discover valid block names and IDs. "
|
||||
"You may also use the first 8 characters of a block UUID."
|
||||
)
|
||||
|
||||
return CopilotPermissions(
|
||||
tools=list(input_data.tools),
|
||||
tools_exclude=input_data.tools_exclude,
|
||||
blocks=input_data.blocks,
|
||||
blocks_exclude=input_data.blocks_exclude,
|
||||
)
|
||||
|
||||
|
||||
def _merge_inherited_permissions(
|
||||
permissions: "CopilotPermissions | None",
|
||||
) -> "tuple[CopilotPermissions | None, contextvars.Token[CopilotPermissions | None] | None]":
|
||||
"""Merge *permissions* with any inherited parent permissions.
|
||||
|
||||
The merged result is stored back into the contextvar so that any nested
|
||||
AutoPilotBlock invocation (sub-agent) inherits the merged ceiling.
|
||||
|
||||
Returns a tuple of (merged_permissions, reset_token). The caller MUST
|
||||
reset the contextvar via ``_inherited_permissions.reset(token)`` in a
|
||||
``finally`` block when ``reset_token`` is not None — this prevents
|
||||
permission leakage between sequential independent executions in the same
|
||||
asyncio task.
|
||||
"""
|
||||
parent = _inherited_permissions.get()
|
||||
|
||||
if permissions is None and parent is None:
|
||||
return None, None
|
||||
|
||||
all_tools = all_known_tool_names()
|
||||
|
||||
if permissions is None:
|
||||
permissions = CopilotPermissions() # allow-all; will be narrowed by parent
|
||||
|
||||
merged = (
|
||||
permissions.merged_with_parent(parent, all_tools)
|
||||
if parent is not None
|
||||
else permissions
|
||||
)
|
||||
|
||||
# Store merged permissions as the new inherited ceiling for nested calls.
|
||||
# Return the token so the caller can restore the previous value in finally.
|
||||
token = _inherited_permissions.set(merged)
|
||||
return merged, token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recovery helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _enqueue_for_recovery(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
dry_run: bool,
|
||||
) -> None:
|
||||
"""Re-enqueue an orphaned sub-agent session so a fresh executor picks it up.
|
||||
|
||||
When ``execute_copilot`` raises an unexpected exception the sub-agent
|
||||
session is left with ``last_role=user`` and no active consumer — identical
|
||||
to the state that caused Toran's reports of silent sub-agents. Publishing
|
||||
the original prompt back to the copilot queue lets the executor service
|
||||
resume the session without manual intervention.
|
||||
|
||||
Skipped for dry-run sessions (no real consumers listen to the queue for
|
||||
simulated sessions). Any failure to publish is logged and swallowed so
|
||||
it never masks the original exception.
|
||||
"""
|
||||
if dry_run:
|
||||
return
|
||||
try:
|
||||
from backend.copilot.executor.utils import ( # avoid circular import
|
||||
enqueue_copilot_turn,
|
||||
)
|
||||
|
||||
await asyncio.wait_for(
|
||||
enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
turn_id=str(uuid.uuid4()),
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
logger.info("AutoPilot session %s enqueued for recovery", session_id[:12])
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"AutoPilot session %s: failed to enqueue for recovery",
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@@ -1,265 +0,0 @@
|
||||
"""Tests for AutoPilotBlock permission fields and validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AutoPilotBlock,
|
||||
_build_and_validate_permissions,
|
||||
_inherited_permissions,
|
||||
_merge_inherited_permissions,
|
||||
)
|
||||
from backend.copilot.permissions import CopilotPermissions, all_known_tool_names
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input(**kwargs) -> AutoPilotBlock.Input:
|
||||
defaults = {
|
||||
"prompt": "Do something",
|
||||
"system_context": "",
|
||||
"session_id": "",
|
||||
"max_recursion_depth": 3,
|
||||
"tools": [],
|
||||
"tools_exclude": True,
|
||||
"blocks": [],
|
||||
"blocks_exclude": True,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return AutoPilotBlock.Input(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_and_validate_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestBuildAndValidatePermissions:
|
||||
async def test_empty_inputs_returns_empty_permissions(self):
|
||||
inp = _make_input()
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.is_empty()
|
||||
|
||||
async def test_valid_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["run_block", "web_fetch"], tools_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.tools == ["run_block", "web_fetch"]
|
||||
assert result.tools_exclude is True
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are now caught at Pydantic validation time
|
||||
(Literal type), before ``_build_and_validate_permissions`` is called."""
|
||||
with pytest.raises(ValidationError, match="not_a_real_tool"):
|
||||
_make_input(tools=["not_a_real_tool"])
|
||||
|
||||
async def test_valid_block_name_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.blocks == ["HTTP Request"]
|
||||
|
||||
async def test_valid_partial_uuid_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["c069dc6b"], blocks_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
|
||||
async def test_invalid_block_identifier_returns_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["totally_fake_block"])
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, str)
|
||||
assert "totally_fake_block" in result
|
||||
assert "Unknown block identifier" in result
|
||||
|
||||
async def test_sdk_builtin_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["Read", "Task", "WebSearch"], tools_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert not result.tools_exclude
|
||||
|
||||
async def test_empty_blocks_skips_validation(self):
|
||||
# Should not call validate_block_identifiers at all when blocks=[].
|
||||
with patch(
|
||||
"backend.copilot.permissions.validate_block_identifiers"
|
||||
) as mock_validate:
|
||||
inp = _make_input(blocks=[])
|
||||
await _build_and_validate_permissions(inp)
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _merge_inherited_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergeInheritedPermissions:
|
||||
def test_no_permissions_no_parent_returns_none(self):
|
||||
merged, token = _merge_inherited_permissions(None)
|
||||
assert merged is None
|
||||
assert token is None
|
||||
|
||||
def test_permissions_no_parent_returned_unchanged(self):
|
||||
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
merged, token = _merge_inherited_permissions(perms)
|
||||
try:
|
||||
assert merged is perms
|
||||
assert token is not None
|
||||
finally:
|
||||
if token is not None:
|
||||
_inherited_permissions.reset(token)
|
||||
|
||||
def test_child_narrows_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
# Set parent as inherited
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
all_t = all_known_tool_names()
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert "bash_exec" not in effective
|
||||
assert "web_fetch" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_none_permissions_with_parent_uses_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
merged, inner_token = _merge_inherited_permissions(None)
|
||||
try:
|
||||
assert merged is not None
|
||||
# Merged should have parent's restrictions
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_child_cannot_expand_parent_whitelist(self):
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
# Child tries to allow more tools
|
||||
child = CopilotPermissions(
|
||||
tools=["run_block", "bash_exec"], tools_exclude=False
|
||||
)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
assert "run_block" in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AutoPilotBlock.run — validation integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAutoPilotBlockRunPermissions:
|
||||
async def _collect_outputs(self, block, input_data, user_id="test-user"):
|
||||
"""Helper to collect all yields from block.run()."""
|
||||
ctx = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id="g1",
|
||||
graph_exec_id="ge1",
|
||||
node_exec_id="ne1",
|
||||
node_id="n1",
|
||||
)
|
||||
outputs = {}
|
||||
async for key, val in block.run(input_data, execution_context=ctx):
|
||||
outputs[key] = val
|
||||
return outputs
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are caught at Pydantic validation (Literal type)."""
|
||||
with pytest.raises(ValidationError, match="not_a_tool"):
|
||||
_make_input(tools=["not_a_tool"])
|
||||
|
||||
async def test_invalid_block_yields_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(blocks=["nonexistent_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "nonexistent_block" in outputs["error"]
|
||||
|
||||
async def test_empty_prompt_yields_error_before_permission_check(self):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(prompt=" ", tools=["run_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "Prompt cannot be empty" in outputs["error"]
|
||||
|
||||
async def test_valid_permissions_passed_to_execute(self):
|
||||
"""Permissions are forwarded to execute_copilot when valid."""
|
||||
block = AutoPilotBlock()
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_execute_copilot(self_inner, **kwargs):
|
||||
captured["permissions"] = kwargs.get("permissions")
|
||||
return (
|
||||
"ok",
|
||||
[],
|
||||
'[{"role":"user","content":"hi"}]',
|
||||
"test-sid",
|
||||
{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
AutoPilotBlock, "create_session", new=AsyncMock(return_value="test-sid")
|
||||
), patch.object(AutoPilotBlock, "execute_copilot", new=fake_execute_copilot):
|
||||
inp = _make_input(tools=["run_block"], tools_exclude=False)
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
|
||||
assert "error" not in outputs
|
||||
perms = captured.get("permissions")
|
||||
assert isinstance(perms, CopilotPermissions)
|
||||
assert perms.tools == ["run_block"]
|
||||
assert perms.tools_exclude is False
|
||||
@@ -1,21 +0,0 @@
|
||||
"""Shared provider config for Ayrshare social-media blocks.
|
||||
|
||||
The "credential" exposed to blocks is the **per-user Ayrshare profile key**,
|
||||
not the org-level ``AYRSHARE_API_KEY``. Profile keys are provisioned per
|
||||
user by :class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`
|
||||
and stored in the normal credentials list with ``is_managed=True``, so every
|
||||
Ayrshare block fits the standard credential flow:
|
||||
|
||||
credentials: CredentialsMetaInput = ayrshare.credentials_field(...)
|
||||
|
||||
``run_block`` / ``resolve_block_credentials`` take care of the rest.
|
||||
|
||||
``with_managed_api_key()`` registers ``api_key`` as a supported auth type
|
||||
without the env-var-backed default credential that ``with_api_key()`` would
|
||||
create — the org-level ``AYRSHARE_API_KEY`` is the admin key and must never
|
||||
reach a block as a "profile key".
|
||||
"""
|
||||
|
||||
from backend.sdk import ProviderBuilder
|
||||
|
||||
ayrshare = ProviderBuilder("ayrshare").with_managed_api_key().build()
|
||||
@@ -1,18 +0,0 @@
|
||||
from backend.sdk import BlockCost, BlockCostType
|
||||
|
||||
# Ayrshare is a subscription proxy ($149/mo Business). Per-post credit charges
|
||||
# prevent a single heavy user from absorbing the fixed cost and align with the
|
||||
# upload cost of each post variant.
|
||||
# cost_filter matches on input_data.is_video BEFORE run() executes, so the flag
|
||||
# has to be correct at input-eval time. Video-only platforms (YouTube, Snapchat)
|
||||
# override the base default to True; platforms that accept both (TikTok, etc.)
|
||||
# rely on the caller setting is_video explicitly for accurate billing.
|
||||
# First match wins in block_usage_cost, so list the video tier first.
|
||||
AYRSHARE_POST_COSTS = (
|
||||
BlockCost(
|
||||
cost_amount=5, cost_type=BlockCostType.RUN, cost_filter={"is_video": True}
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=2, cost_type=BlockCostType.RUN, cost_filter={"is_video": False}
|
||||
),
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user